Skip to content

Commit

Permalink
Make from_dict more flexible, and add from_pytree (#2291)
Browse files Browse the repository at this point in the history
* Make `from_dict` more flexible, and add `from_pytree`

* restructure docs

* fix ipython directive

* make pytree_to_dataset available at top level

---------

Co-authored-by: Oriol (ProDesk) <[email protected]>
  • Loading branch information
ColCarroll and OriolAbril authored Mar 14, 2024
1 parent 6f50066 commit 29ca5a1
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## v0.x.x Unreleased

### New features
- Support for `pytree`s and robust to nested dictionaries. (2291)

### Maintenance and fixes

Expand Down
6 changes: 4 additions & 2 deletions arviz/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Code for loading and manipulating data structures."""
from .base import CoordSpec, DimSpec, dict_to_dataset, numpy_to_data_array
from .base import CoordSpec, DimSpec, dict_to_dataset, numpy_to_data_array, pytree_to_dataset
from .converters import convert_to_dataset, convert_to_inference_data
from .datasets import clear_data_home, list_datasets, load_arviz_data
from .inference_data import InferenceData, concat
from .io_beanmachine import from_beanmachine
from .io_cmdstan import from_cmdstan
from .io_cmdstanpy import from_cmdstanpy
from .io_datatree import from_datatree, to_datatree
from .io_dict import from_dict
from .io_dict import from_dict, from_pytree
from .io_emcee import from_emcee
from .io_json import from_json, to_json
from .io_netcdf import from_netcdf, to_netcdf
Expand Down Expand Up @@ -38,10 +38,12 @@
"from_cmdstanpy",
"from_datatree",
"from_dict",
"from_pytree",
"from_json",
"from_pyro",
"from_numpyro",
"from_netcdf",
"pytree_to_dataset",
"to_datatree",
"to_json",
"to_netcdf",
Expand Down
110 changes: 100 additions & 10 deletions arviz/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union

import numpy as np
import tree
import xarray as xr

try:
Expand Down Expand Up @@ -67,6 +68,48 @@ def wrapped(cls: RequiresArgTypeT) -> Optional[RequiresReturnTypeT]:
return wrapped


def _yield_flat_up_to(shallow_tree, input_tree, path=()):
"""Yields (path, value) pairs of input_tree flattened up to shallow_tree.
Adapted from dm-tree (https://github.com/google-deepmind/tree) to allow
lists as leaves.
Args:
shallow_tree: Nested structure. Traverse no further than its leaf nodes.
input_tree: Nested structure. Return the paths and values from this tree.
Must have the same upper structure as shallow_tree.
path: Tuple. Optional argument, only used when recursing. The path from the
root of the original shallow_tree, down to the root of the shallow_tree
arg of this recursive call.
Yields:
Pairs of (path, value), where path the tuple path of a leaf node in
shallow_tree, and value is the value of the corresponding node in
input_tree.
"""
# pylint: disable=protected-access
if isinstance(shallow_tree, tree._TEXT_OR_BYTES) or not (
isinstance(shallow_tree, tree.collections_abc.Mapping)
or tree._is_namedtuple(shallow_tree)
or tree._is_attrs(shallow_tree)
):
yield (path, input_tree)
else:
input_tree = dict(tree._yield_sorted_items(input_tree))
for shallow_key, shallow_subtree in tree._yield_sorted_items(shallow_tree):
subpath = path + (shallow_key,)
input_subtree = input_tree[shallow_key]
for leaf_path, leaf_value in _yield_flat_up_to(
shallow_subtree, input_subtree, path=subpath
):
yield (leaf_path, leaf_value)
# pylint: enable=protected-access


def _flatten_with_path(structure):
return list(_yield_flat_up_to(structure, structure))


def generate_dims_coords(
shape,
var_name,
Expand Down Expand Up @@ -255,7 +298,7 @@ def numpy_to_data_array(
return xr.DataArray(ary, coords=coords, dims=dims)


def dict_to_dataset(
def pytree_to_dataset(
data,
*,
attrs=None,
Expand All @@ -266,42 +309,86 @@ def dict_to_dataset(
index_origin=None,
skip_event_dims=None,
):
"""Convert a dictionary of numpy arrays to an xarray.Dataset.
"""Convert a dictionary or pytree of numpy arrays to an xarray.Dataset.
See https://jax.readthedocs.io/en/latest/pytrees.html for what a pytree is, but
this inclues at least dictionaries and tuple types.
Parameters
----------
data : dict[str] -> ndarray
data : dict of {str : array_like or dict} or pytree
Data to convert. Keys are variable names.
attrs : dict
attrs : dict, optional
Json serializable metadata to attach to the dataset, in addition to defaults.
library : module
library : module, optional
Library used for performing inference. Will be attached to the attrs metadata.
coords : dict[str] -> ndarray
coords : dict of {str : ndarray}, optional
Coordinates for the dataset
dims : dict[str] -> list[str]
dims : dict of {str : list of str}, optional
Dimensions of each variable. The keys are variable names, values are lists of
coordinates.
default_dims : list of str, optional
Passed to :py:func:`numpy_to_data_array`
index_origin : int, optional
Passed to :py:func:`numpy_to_data_array`
skip_event_dims : bool
skip_event_dims : bool, optional
If True, cut extra dims whenever present to match the shape of the data.
Necessary for PPLs which have the same name in both observed data and log
likelihood groups, to account for their different shapes when observations are
multivariate.
Returns
-------
xr.Dataset
xarray.Dataset
In case of nested pytrees, the variable name will be a tuple of individual names.
Notes
-----
This function is available through two aliases: ``dict_to_dataset`` or ``pytree_to_dataset``.
Examples
--------
dict_to_dataset({'x': np.random.randn(4, 100), 'y': np.random.rand(4, 100)})
Convert a dictionary with two 2D variables to a Dataset.
.. ipython::
In [1]: import arviz as az
...: import numpy as np
...: az.dict_to_dataset({'x': np.random.randn(4, 100), 'y': np.random.rand(4, 100)})
Note that unlike the :class:`xarray.Dataset` constructor, ArviZ has added extra
information to the generated Dataset such as default dimension names for sampled
dimensions and some attributes.
The function is also general enough to work on pytrees such as nested dictionaries:
.. ipython::
In [1]: az.pytree_to_dataset({'top': {'second': 1.}, 'top2': 1.})
which has two variables (as many as leafs) named ``('top', 'second')`` and ``top2``.
Dimensions and co-ordinates can be defined as usual:
.. ipython::
In [1]: datadict = {
...: "top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
...: "d": np.random.randn(100),
...: }
...: az.dict_to_dataset(
...: datadict,
...: coords={"c": np.arange(10)},
...: dims={("top", "b"): ["c"]}
...: )
"""
if dims is None:
dims = {}
try:
data = {k[0] if len(k) == 1 else k: v for k, v in _flatten_with_path(data)}
except TypeError: # probably unsortable keys -- the function will still work if
pass # it is an honest dictionary.

data_vars = {
key: numpy_to_data_array(
Expand All @@ -318,6 +405,9 @@ def dict_to_dataset(
return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))


dict_to_dataset = pytree_to_dataset


def make_attrs(attrs=None, library=None):
"""Make standard attributes to attach to xarray datasets.
Expand Down
4 changes: 4 additions & 0 deletions arviz/data/converters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""High level conversion functions."""
import numpy as np
import tree
import xarray as xr

from .base import dict_to_dataset
Expand Down Expand Up @@ -105,6 +106,8 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
dataset = obj.to_dataset()
elif isinstance(obj, dict):
dataset = dict_to_dataset(obj, coords=coords, dims=dims)
elif tree.is_nested(obj) and not isinstance(obj, (list, tuple)):
dataset = dict_to_dataset(obj, coords=coords, dims=dims)
elif isinstance(obj, np.ndarray):
dataset = dict_to_dataset({"x": obj}, coords=coords, dims=dims)
elif isinstance(obj, (list, tuple)) and isinstance(obj[0], str) and obj[0].endswith(".csv"):
Expand All @@ -118,6 +121,7 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
"xarray dataarray",
"xarray dataset",
"dict",
"pytree",
"netcdf filename",
"numpy array",
"pystan fit",
Expand Down
3 changes: 3 additions & 0 deletions arviz/data/io_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,3 +458,6 @@ def from_dict(
attrs=attrs,
**kwargs,
).to_inference_data()


from_pytree = from_dict
2 changes: 1 addition & 1 deletion arviz/plots/backends/matplotlib/pairplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def plot_pair(
if reference_values:
x_name = flat_var_names[i]
y_name = flat_var_names[j + not_marginals]
if x_name and y_name not in difference:
if (x_name not in difference) and (y_name not in difference):
ax[j, i].plot(
reference_values_copy[x_name],
reference_values_copy[y_name],
Expand Down
14 changes: 14 additions & 0 deletions arviz/tests/base_tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,20 @@ def test_dict_to_dataset():
assert set(dataset.b.coords) == {"chain", "draw", "c"}


def test_nested_dict_to_dataset():
datadict = {
"top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
"d": np.random.randn(100),
}
dataset = convert_to_dataset(datadict, coords={"c": np.arange(10)}, dims={("top", "b"): ["c"]})
assert set(dataset.data_vars) == {("top", "a"), ("top", "b"), "d"}
assert set(dataset.coords) == {"chain", "draw", "c"}

assert set(dataset[("top", "a")].coords) == {"chain", "draw"}
assert set(dataset[("top", "b")].coords) == {"chain", "draw", "c"}
assert set(dataset.d.coords) == {"chain", "draw"}


def test_dict_to_dataset_event_dims_error():
datadict = {"a": np.random.randn(1, 100, 10)}
coords = {"b": np.arange(10), "c": ["x", "y", "z"]}
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ numpy>=1.22.0,<2.0
scipy>=1.8.0
packaging
pandas>=1.4.0
dm-tree>=0.1.8
xarray>=0.21.0
h5netcdf>=1.0.2
typing_extensions>=4.1.0
Expand Down

0 comments on commit 29ca5a1

Please sign in to comment.