Skip to content

Commit

Permalink
Merge pull request #82 from normal-computing/tree-utils
Browse files Browse the repository at this point in the history
Split tree_utils from utils
  • Loading branch information
SamDuffield authored Apr 30, 2024
2 parents 1d9c398 + 3371847 commit 2b0874a
Show file tree
Hide file tree
Showing 14 changed files with 526 additions and 514 deletions.
1 change: 1 addition & 0 deletions docs/api/tree_utils.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: posteriors.tree_utils
4 changes: 1 addition & 3 deletions docs/tutorials/ekf_premier_league.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ We'll use the [football-data.co.uk](https://www.football-data.co.uk/englandm.php
data = data.dropna()
data["Timestamp"] = pd.to_datetime(data["Date"], dayfirst=True)
data["Timestamp"] = pd.to_datetime(data["Timestamp"], unit="D")
data["TimestampDays"] = (
(data["Timestamp"] - origin_date).dt.days.astype(int)
)
data["TimestampDays"] = (data["Timestamp"] - origin_date).dt.days.astype(int)

players_arr = pd.unique(pd.concat([data["HomeTeam"], data["AwayTeam"]]))
players_arr.sort()
Expand Down
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ nav:
- Diag: api/vi/diag.md
- api/optim.md
- TorchOpt: api/torchopt.md
- api/tree_utils.md
- api/types.md
- api/utils.md

Expand Down
23 changes: 12 additions & 11 deletions posteriors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,21 @@
from posteriors.utils import cg
from posteriors.utils import diag_normal_log_prob
from posteriors.utils import diag_normal_sample
from posteriors.utils import tree_size
from posteriors.utils import tree_extract
from posteriors.utils import tree_insert
from posteriors.utils import tree_insert_
from posteriors.utils import extract_requires_grad
from posteriors.utils import insert_requires_grad
from posteriors.utils import insert_requires_grad_
from posteriors.utils import extract_requires_grad_and_func
from posteriors.utils import inplacify
from posteriors.utils import tree_map_inplacify_
from posteriors.utils import flexi_tree_map
from posteriors.utils import per_samplify
from posteriors.utils import is_scalar

from posteriors.tree_utils import tree_size
from posteriors.tree_utils import tree_extract
from posteriors.tree_utils import tree_insert
from posteriors.tree_utils import tree_insert_
from posteriors.tree_utils import extract_requires_grad
from posteriors.tree_utils import insert_requires_grad
from posteriors.tree_utils import insert_requires_grad_
from posteriors.tree_utils import extract_requires_grad_and_func
from posteriors.tree_utils import inplacify
from posteriors.tree_utils import tree_map_inplacify_
from posteriors.tree_utils import flexi_tree_map

import logging

logger = logging.getLogger("torch.distributed.elastic.multiprocessing.redirects")
Expand Down
2 changes: 1 addition & 1 deletion posteriors/ekf/diag_fisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from dataclasses import dataclass

from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
from posteriors.tree_utils import flexi_tree_map
from posteriors.utils import (
diag_normal_sample,
flexi_tree_map,
per_samplify,
is_scalar,
CatchAuxError,
Expand Down
2 changes: 1 addition & 1 deletion posteriors/laplace/dense_fisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from optree.integration.torch import tree_ravel

from posteriors.types import TensorTree, Transform, LogProbFn, Tensor, TransformState
from posteriors.tree_utils import tree_size
from posteriors.utils import (
per_samplify,
tree_size,
empirical_fisher,
is_scalar,
CatchAuxError,
Expand Down
2 changes: 1 addition & 1 deletion posteriors/laplace/diag_fisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from dataclasses import dataclass

from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
from posteriors.tree_utils import flexi_tree_map
from posteriors.utils import (
diag_normal_sample,
flexi_tree_map,
per_samplify,
is_scalar,
CatchAuxError,
Expand Down
3 changes: 2 additions & 1 deletion posteriors/sgmcmc/sghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from dataclasses import dataclass

from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
from posteriors.utils import flexi_tree_map, is_scalar, CatchAuxError
from posteriors.tree_utils import flexi_tree_map
from posteriors.utils import is_scalar, CatchAuxError


def build(
Expand Down
3 changes: 2 additions & 1 deletion posteriors/sgmcmc/sgld.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from dataclasses import dataclass

from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
from posteriors.utils import flexi_tree_map, CatchAuxError
from posteriors.tree_utils import flexi_tree_map
from posteriors.utils import CatchAuxError


def build(
Expand Down
278 changes: 278 additions & 0 deletions posteriors/tree_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,278 @@
from typing import Callable, Tuple
import torch
from optree import tree_map, tree_map_, tree_reduce

from posteriors.types import TensorTree


def tree_size(tree: TensorTree) -> int:
    """Returns the total number of elements in a PyTree.

    Not the number of leaves, but the total number of elements for all tensors in the
    tree.

    Args:
        tree: A PyTree of tensors. Leaves that are not tensors are converted with
            `torch.tensor` before their elements are counted.

    Returns:
        Number of elements in the PyTree as a Python int (0 for a tree without
        leaves).
    """

    def leaf_numel(x) -> int:
        # Non-tensor leaves are wrapped in a tensor so that `numel` is available.
        return (x if isinstance(x, torch.Tensor) else torch.tensor(x)).numel()

    # Plain integer addition with an initial value of 0 keeps the result a Python
    # int, matching the annotation, and makes a tree without leaves return 0.
    return tree_reduce(lambda total, leaf: total + leaf_numel(leaf), tree, 0)


def tree_extract(f: Callable[[torch.Tensor], bool], tree: TensorTree) -> TensorTree:
    """Extracts values from a PyTree where f returns True.

    False values are replaced with empty tensors.

    Args:
        f: A function that takes a PyTree element and returns True or False.
        tree: A PyTree.

    Returns:
        A PyTree with the same structure as tree, holding the original leaf where f
        returns True and an empty tensor elsewhere.
    """
    # The placeholder is created on the device of the leaf it replaces.
    return tree_map(lambda x: x if f(x) else torch.tensor([], device=x.device), tree)


def tree_insert(
    f: Callable[[torch.Tensor], bool], full_tree: TensorTree, sub_tree: TensorTree
) -> TensorTree:
    """Inserts sub_tree into full_tree where full_tree tensors evaluate f to True.

    Both PyTrees must have the same structure.

    Args:
        f: A function that takes a PyTree element and returns True or False.
            It is evaluated on the leaves of full_tree.
        full_tree: A PyTree to insert sub_tree into.
        sub_tree: A PyTree to insert into full_tree.

    Returns:
        A PyTree with sub_tree inserted into full_tree.
    """
    return tree_map(
        # Take the sub_tree leaf where the predicate holds for the matching
        # full_tree leaf, otherwise keep the full_tree leaf.
        lambda sub, full: sub if f(full) else full,
        sub_tree,
        full_tree,
    )


def tree_insert_(
    f: Callable[[torch.Tensor], bool], full_tree: TensorTree, sub_tree: TensorTree
) -> TensorTree:
    """Inserts sub_tree into full_tree in-place where full_tree tensors evaluate
    f to True. Both PyTrees must have the same structure.

    Args:
        f: A function that takes a PyTree element and returns True or False.
            It is evaluated on the leaves of full_tree.
        full_tree: A PyTree to insert sub_tree into.
        sub_tree: A PyTree to insert into full_tree.

    Returns:
        A pointer to full_tree with sub_tree inserted.
    """

    def insert_(full, sub):
        # Side-effect only: overwrite the data of the full_tree leaf when selected.
        if f(full):
            full.data = sub.data

    return tree_map_(insert_, full_tree, sub_tree)


def extract_requires_grad(tree: TensorTree) -> TensorTree:
    """Keeps only the tensors of a PyTree that require gradients.

    Delegates to `tree_extract` with a predicate that reads the `requires_grad`
    attribute of each leaf.

    Args:
        tree: A PyTree of tensors.

    Returns:
        A PyTree of tensors that require gradients.
    """

    def needs_grad(leaf):
        return leaf.requires_grad

    return tree_extract(needs_grad, tree)


def insert_requires_grad(full_tree: TensorTree, sub_tree: TensorTree) -> TensorTree:
    """Puts sub_tree leaves into full_tree wherever the full_tree tensor has
    `requires_grad` set.

    Both PyTrees must have the same structure. Delegates to `tree_insert`.

    Args:
        full_tree: A PyTree to insert sub_tree into.
        sub_tree: A PyTree to insert into full_tree.

    Returns:
        A PyTree with sub_tree inserted into full_tree.
    """

    def needs_grad(leaf):
        return leaf.requires_grad

    return tree_insert(needs_grad, full_tree, sub_tree)


def insert_requires_grad_(full_tree: TensorTree, sub_tree: TensorTree) -> TensorTree:
    """Puts sub_tree leaves into full_tree in-place wherever the full_tree tensor
    has `requires_grad` set.

    Both PyTrees must have the same structure. Delegates to `tree_insert_`.

    Args:
        full_tree: A PyTree to insert sub_tree into.
        sub_tree: A PyTree to insert into full_tree.

    Returns:
        A pointer to full_tree with sub_tree inserted.
    """

    def needs_grad(leaf):
        return leaf.requires_grad

    return tree_insert_(needs_grad, full_tree, sub_tree)


def extract_requires_grad_and_func(
    tree: TensorTree, func: Callable, inplace: bool = False
) -> Tuple[TensorTree, Callable]:
    """Builds the gradient-requiring subtree of `tree` together with a version of
    `func` that accepts that subtree.

    `func` takes the full parameter tree as its first argument; the returned
    function takes the subtree instead, re-inserts it into `tree` and forwards
    the result (plus any further arguments) to `func`.

    Args:
        tree: A PyTree of tensors.
        func: A function that takes tree in its first argument.
        inplace: Whether to modify the tree in-place or not when the new function
            is called.

    Returns:
        A PyTree of tensors that require gradients and a modified func that takes the
        subtree structure rather than full tree in its first argument.
    """
    # Choose the re-insertion routine once, up front.
    if inplace:
        reinsert = insert_requires_grad_
    else:
        reinsert = insert_requires_grad

    def func_on_subtree(subtree, *args, **kwargs):
        rebuilt = reinsert(tree, subtree)
        return func(rebuilt, *args, **kwargs)

    return extract_requires_grad(tree), func_on_subtree


def inplacify(func: Callable) -> Callable:
    """Wraps a tensor-returning function so that it updates its first argument
    in-place.

    The wrapper accepts the same arguments as `func`, calls it, stores the
    output in the first argument tensor and returns that same tensor.

    Args:
        func: A function that takes a tensor as its first argument and returns
            a modified version of said tensor.

    Returns:
        A function that takes a tensor as its first argument and modifies it
        in-place.
    """

    def inplace_func(tens, *args, **kwargs):
        updated = func(tens, *args, **kwargs)
        # Assign through `.data` so the very same tensor object is handed back.
        tens.data = updated
        return tens

    return inplace_func


def tree_map_inplacify_(
    func: Callable,
    tree: TensorTree,
    *rests: TensorTree,
    is_leaf: Callable[[TensorTree], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> TensorTree:
    """Applies a pure function to each tensor in a PyTree in-place.

    Like [optree.tree_map_](https://optree.readthedocs.io/en/latest/ops.html#optree.tree_map_)
    but takes a pure function as input (and replaces its first argument with its
    output in-place) rather than a side-effect function.

    Args:
        func: A function that takes a tensor as its first argument and returns
            a modified version of said tensor.
        tree (pytree): A pytree to be mapped over, with each leaf providing the first
            positional argument to function ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same
            structure as ``tree`` or has ``tree`` as a prefix.
        is_leaf (callable, optional): An optionally specified function that will be
            called at each flattening step. It should return a boolean, with
            `True` stopping the traversal and the whole subtree being treated as a
            leaf, and `False` indicating the flattening should traverse the
            current object.
        none_is_leaf (bool, optional): Whether to treat `None` as a leaf. If
            `False`, `None` is a non-leaf node with arity 0. Thus `None` is contained
            in the treespec rather than in the leaves list and `None` will remain in
            the result pytree. (default: `False`)
        namespace (str, optional): The registry namespace used for custom pytree node
            types. (default: :const:`''`, i.e., the global namespace)

    Returns:
        The original ``tree``, where the value at each leaf has been replaced
        in-place by ``func(x, *xs)``, with ``x`` the value at the corresponding leaf
        in ``tree`` and ``xs`` the tuple of values at the corresponding nodes in
        ``rests``.
    """
    # Traversal options are forwarded to optree unchanged.
    traversal_options = {
        "is_leaf": is_leaf,
        "none_is_leaf": none_is_leaf,
        "namespace": namespace,
    }
    inplace_func = inplacify(func)
    return tree_map_(inplace_func, tree, *rests, **traversal_options)


def flexi_tree_map(
    func: Callable,
    tree: TensorTree,
    *rests: TensorTree,
    inplace: bool = False,
    is_leaf: Callable[[TensorTree], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> TensorTree:
    """Applies a pure function to each tensor in a PyTree, with inplace argument.

    The function is applied leaf-wise as

    ```
    out_tensor = func(tensor, *rest_tensors)
    ```

    where `out_tensor` is of the same shape as `tensor`.
    Therefore

    ```
    out_tree = flexi_tree_map(func, tree, *rests, inplace=True)
    ```

    will return `out_tree` a pointer to the original `tree` with leaves (tensors)
    modified in place.
    If `inplace=False`, `flexi_tree_map` is equivalent to [`optree.tree_map`](https://optree.readthedocs.io/en/latest/ops.html#optree.tree_map)
    and returns a new tree.

    Args:
        func: A pure function that takes a tensor as its first argument and returns
            a modified version of said tensor.
        tree (pytree): A pytree to be mapped over, with each leaf providing the first
            positional argument to function ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same
            structure as ``tree`` or has ``tree`` as a prefix.
        inplace (bool, optional): Whether to modify the tree in-place or not.
        is_leaf (callable, optional): An optionally specified function that will be
            called at each flattening step. It should return a boolean, with `True`
            stopping the traversal and the whole subtree being treated as a leaf, and
            `False` indicating the flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat `None` as a leaf. If `False`,
            `None` is a non-leaf node with arity 0. Thus `None` is contained in the
            treespec rather than in the leaves list and `None` will remain in the
            result pytree. (default: `False`)
        namespace (str, optional): The registry namespace used for custom pytree node
            types. (default: :const:`''`, i.e., the global namespace)

    Returns:
        Either the original tree modified in-place or a new tree depending on the
        `inplace` argument.
    """
    traversal_options = {
        "is_leaf": is_leaf,
        "none_is_leaf": none_is_leaf,
        "namespace": namespace,
    }
    if inplace:
        # Overwrite the leaves of `tree` and hand back `tree` itself.
        return tree_map_inplacify_(func, tree, *rests, **traversal_options)
    return tree_map(func, tree, *rests, **traversal_options)
Loading

0 comments on commit 2b0874a

Please sign in to comment.