-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ENH] Add functional submodule (#75)
* Add functional submodule * update lock file --------- Signed-off-by: Adam Li <[email protected]>
- Loading branch information
Showing
7 changed files
with
154 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,3 +20,4 @@ | |
from . import classes | ||
from . import networkx | ||
from . import simulate | ||
from . import functional |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .linear import make_graph_linear_gaussian |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
from typing import Callable, List, Optional | ||
|
||
import networkx as nx | ||
import numpy as np | ||
|
||
|
||
def make_graph_linear_gaussian( | ||
G: nx.DiGraph, | ||
node_mean_lims: Optional[List[float]] = None, | ||
node_std_lims: Optional[List[float]] = None, | ||
edge_functions: List[Callable[[float], float]] = None, | ||
edge_weight_lims: Optional[List[float]] = None, | ||
random_state=None, | ||
): | ||
r"""Convert an existing DAG to a linear Gaussian graphical model. | ||
All nodes are sampled from a normal distribution with parametrizations | ||
defined uniformly at random between the limits set by the input parameters. | ||
The edges apply then a weight and a function based on the inputs in an additive fashion. | ||
For node :math:`X_i`, we have: | ||
.. math:: | ||
X_i = \\sum_{j \in parents} w_j f_j(X_j) + \\epsilon_i | ||
where: | ||
- :math:`\\epsilon_i \sim N(\mu_i, \sigma_i)`, where :math:`\mu_i` is sampled | ||
uniformly at random from `node_mean_lims` and :math:`\sigma_i` is sampled | ||
uniformly at random from `node_std_lims`. | ||
- :math:`w_j \sim U(\\text{edge_weight_lims})` | ||
- :math:`f_j` is a function sampled uniformly at random | ||
from `edge_functions` | ||
Parameters | ||
---------- | ||
G : NetworkX DiGraph | ||
The graph to sample data from. The graph will be modified in-place | ||
to get the weights and functions of the edges. | ||
node_mean_lims : Optional[List[float]], optional | ||
The lower and upper bounds of the mean of the Gaussian random variable, by default None, | ||
which defaults to a mean of 0. | ||
node_std_lims : Optional[List[float]], optional | ||
The lower and upper bounds of the std of the Gaussian random variable, by default None, | ||
which defaults to a std of 1. | ||
edge_functions : List[Callable[float]], optional | ||
The set of edge functions that take in an iid sample from the parent and computes | ||
a transformation (possibly nonlinear), such as ``(lambda x: x**2, lambda x: x)``, | ||
by default None, which defaults to the identity function ``lambda x: x``. | ||
edge_weight_lims : Optional[List[float]], optional | ||
The lower and upper bounds of the edge weight, by default None, | ||
which defaults to a weight of 1. | ||
random_state : int, optional | ||
Random seed, by default None. | ||
Returns | ||
------- | ||
G : NetworkX DiGraph | ||
NetworkX graph with the edge weights and functions set with node attributes | ||
set with ``'parent_functions'``, and ``'gaussian_noise_function'``. Moreover | ||
the graph attribute ``'linear_gaussian'`` is set to ``True``. | ||
""" | ||
if not nx.is_directed_acyclic_graph(G): | ||
raise ValueError("The input graph must be a DAG.") | ||
rng = np.random.default_rng(random_state) | ||
|
||
if node_mean_lims is None: | ||
node_mean_lims = [0, 0] | ||
elif len(node_mean_lims) != 2: | ||
raise ValueError("node_mean_lims must be a list of length 2.") | ||
if node_std_lims is None: | ||
node_std_lims = [1, 1] | ||
elif len(node_std_lims) != 2: | ||
raise ValueError("node_std_lims must be a list of length 2.") | ||
if edge_functions is None: | ||
edge_functions = [lambda x: x] | ||
if edge_weight_lims is None: | ||
edge_weight_lims = [1, 1] | ||
elif len(edge_weight_lims) != 2: | ||
raise ValueError("edge_weight_lims must be a list of length 2.") | ||
|
||
# Create list of topologically sorted nodes | ||
top_sort_idx = list(nx.topological_sort(G)) | ||
|
||
for node_idx in top_sort_idx: | ||
# get all parents | ||
parents = sorted(list(G.predecessors(node_idx))) | ||
|
||
# sample noise | ||
mean = rng.uniform(low=node_mean_lims[0], high=node_mean_lims[1]) | ||
std = rng.uniform(low=node_std_lims[0], high=node_std_lims[1]) | ||
|
||
# sample weight and edge function for each parent | ||
node_function = dict() | ||
for parent in parents: | ||
weight = rng.uniform(low=edge_weight_lims[0], high=edge_weight_lims[1]) | ||
func = rng.choice(edge_functions) | ||
node_function.update({parent: {"weight": weight, "func": func}}) | ||
|
||
# set the node attribute "functions" to hold the weight and function wrt each parent | ||
nx.set_node_attributes(G, {node_idx: node_function}, "parent_functions") | ||
nx.set_node_attributes(G, {node_idx: {"mean": mean, "std": std}}, "gaussian_noise_function") | ||
G.graph["linear_gaussian"] = True | ||
return G |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import networkx as nx | ||
import pytest | ||
|
||
from pywhy_graphs.functional import make_graph_linear_gaussian | ||
from pywhy_graphs.simulate import simulate_random_er_dag | ||
|
||
|
||
def test_make_linear_gaussian_graph(): | ||
G = simulate_random_er_dag(n_nodes=5, seed=12345, ensure_acyclic=True) | ||
|
||
G = make_graph_linear_gaussian(G, random_state=12345) | ||
|
||
assert all(key in nx.get_node_attributes(G, "parent_functions") for key in G.nodes) | ||
assert all(key in nx.get_node_attributes(G, "gaussian_noise_function") for key in G.nodes) | ||
|
||
|
||
def test_make_linear_gaussian_graph_errors(): | ||
G = simulate_random_er_dag(n_nodes=2, seed=12345, ensure_acyclic=True) | ||
|
||
with pytest.raises(ValueError, match="must be a list of length 2."): | ||
G = make_graph_linear_gaussian(G, node_mean_lims=[0], random_state=12345) | ||
|
||
with pytest.raises(ValueError, match="must be a list of length 2."): | ||
G = make_graph_linear_gaussian(G, node_std_lims=[0], random_state=12345) | ||
|
||
with pytest.raises(ValueError, match="must be a list of length 2."): | ||
G = make_graph_linear_gaussian(G, edge_weight_lims=[0], random_state=12345) | ||
|
||
with pytest.raises(ValueError, match="The input graph must be a DAG."): | ||
G = make_graph_linear_gaussian( | ||
nx.cycle_graph(4, create_using=nx.DiGraph), random_state=12345 | ||
) |