diff --git a/docs/api.rst b/docs/api.rst index 0dc5fa431..64b035e75 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -147,6 +147,17 @@ a SCM and their data starting from the causal graph. simulate.simulate_data_from_var simulate.simulate_var_process_from_summary_graph +Converting graphs to functional models +====================================== +An experimental submodule for converting graphs to functional models, such as +linear structural equation Gaussian models (SEMs). + +.. currentmodule:: pywhy_graphs.functional + +.. autosummary:: + :toctree: generated/ + + make_graph_linear_gaussian Visualization of causal graphs ============================== diff --git a/docs/whats_new/v0.1.rst b/docs/whats_new/v0.1.rst index 52bea36b0..0288ba40e 100644 --- a/docs/whats_new/v0.1.rst +++ b/docs/whats_new/v0.1.rst @@ -43,6 +43,7 @@ Changelog - |Feature| Implement export/import functions to go to/from pywhy-graphs to pcalg and tetrad, by `Adam Li`_ (:pr:`60`) - |Feature| Implement export/import functions to go to/from pywhy-graphs to ananke-causal, by `Jaron Lee`_ (:pr:`63`) - |Feature| Implement pre-commit hooks for development, by `Jaron Lee`_ (:pr:`68`) +- |Feature| Implement a new submodule for converting graphs to a functional model, with :func:`pywhy_graphs.functional.make_graph_linear_gaussian`, by `Adam Li`_ (:pr:`75`) Code and Documentation Contributors ----------------------------------- diff --git a/poetry.lock b/poetry.lock index d8a469a6f..4aa1e26f2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3766,14 +3766,14 @@ socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] [[package]] name = "virtualenv" -version = "20.22.0" +version = "20.23.0" description = "Virtual Python Environment builder" category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "virtualenv-20.22.0-py3-none-any.whl", hash = "sha256:48fd3b907b5149c5aab7c23d9790bea4cac6bc6b150af8635febc4cfeab1275a"}, - {file = "virtualenv-20.22.0.tar.gz", hash = "sha256:278753c47aaef1a0f14e6db8a4c5e1e040e90aea654d0fc1dc7e0d8a42616cc3"}, + {file = "virtualenv-20.23.0-py3-none-any.whl", hash = "sha256:6abec7670e5802a528357fdc75b26b9f57d5d92f29c5462ba0fbe45feacc685e"}, + {file = "virtualenv-20.23.0.tar.gz", hash = "sha256:a85caa554ced0c0afbd0d638e7e2d7b5f92d23478d05d17a76daeac8f279f924"}, ] [package.dependencies] @@ -3783,7 +3783,7 @@ platformdirs = ">=3.2,<4" [package.extras] docs = ["furo (>=2023.3.27)", "proselint (>=0.13)", "sphinx (>=6.1.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=22.12)"] -test = ["covdefaults (>=2.3)", "coverage (>=7.2.3)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.3.1)", "pytest-env (>=0.8.1)", "pytest-freezegun (>=0.4.2)", "pytest-mock (>=3.10)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.3)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.3.1)", "pytest-env (>=0.8.1)", "pytest-freezegun (>=0.4.2)", "pytest-mock (>=3.10)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=67.7.1)", "time-machine (>=2.9)"] [[package]] name = "watchdog" diff --git a/pywhy_graphs/__init__.py b/pywhy_graphs/__init__.py index 6b1b212e3..92d9e806b 100644 --- a/pywhy_graphs/__init__.py +++ b/pywhy_graphs/__init__.py @@ -20,3 +20,4 @@ from . import classes from . import networkx from . import simulate +from . import functional diff --git a/pywhy_graphs/functional/__init__.py b/pywhy_graphs/functional/__init__.py new file mode 100644 index 000000000..84c56d73d --- /dev/null +++ b/pywhy_graphs/functional/__init__.py @@ -0,0 +1 @@ +from .linear import make_graph_linear_gaussian diff --git a/pywhy_graphs/functional/linear.py b/pywhy_graphs/functional/linear.py new file mode 100644 index 000000000..4cd51df11 --- /dev/null +++ b/pywhy_graphs/functional/linear.py @@ -0,0 +1,104 @@ +from typing import Callable, List, Optional + +import networkx as nx +import numpy as np + + +def make_graph_linear_gaussian( + G: nx.DiGraph, + node_mean_lims: Optional[List[float]] = None, + node_std_lims: Optional[List[float]] = None, + edge_functions: List[Callable[[float], float]] = None, + edge_weight_lims: Optional[List[float]] = None, + random_state=None, +): + r"""Convert an existing DAG to a linear Gaussian graphical model. + + All nodes are sampled from a normal distribution with parametrizations + defined uniformly at random between the limits set by the input parameters. + The edges apply then a weight and a function based on the inputs in an additive fashion. + For node :math:`X_i`, we have: + + .. math:: + + X_i = \\sum_{j \in parents} w_j f_j(X_j) + \\epsilon_i + + where: + + - :math:`\\epsilon_i \sim N(\mu_i, \sigma_i)`, where :math:`\mu_i` is sampled + uniformly at random from `node_mean_lims` and :math:`\sigma_i` is sampled + uniformly at random from `node_std_lims`. + - :math:`w_j \sim U(\\text{edge_weight_lims})` + - :math:`f_j` is a function sampled uniformly at random + from `edge_functions` + + Parameters + ---------- + G : NetworkX DiGraph + The graph to sample data from. The graph will be modified in-place + to get the weights and functions of the edges. + node_mean_lims : Optional[List[float]], optional + The lower and upper bounds of the mean of the Gaussian random variable, by default None, + which defaults to a mean of 0. + node_std_lims : Optional[List[float]], optional + The lower and upper bounds of the std of the Gaussian random variable, by default None, + which defaults to a std of 1. + edge_functions : List[Callable[float]], optional + The set of edge functions that take in an iid sample from the parent and computes + a transformation (possibly nonlinear), such as ``(lambda x: x**2, lambda x: x)``, + by default None, which defaults to the identity function ``lambda x: x``. + edge_weight_lims : Optional[List[float]], optional + The lower and upper bounds of the edge weight, by default None, + which defaults to a weight of 1. + random_state : int, optional + Random seed, by default None. + + Returns + ------- + G : NetworkX DiGraph + NetworkX graph with the edge weights and functions set with node attributes + set with ``'parent_functions'``, and ``'gaussian_noise_function'``. Moreover + the graph attribute ``'linear_gaussian'`` is set to ``True``. + """ + if not nx.is_directed_acyclic_graph(G): + raise ValueError("The input graph must be a DAG.") + rng = np.random.default_rng(random_state) + + if node_mean_lims is None: + node_mean_lims = [0, 0] + elif len(node_mean_lims) != 2: + raise ValueError("node_mean_lims must be a list of length 2.") + if node_std_lims is None: + node_std_lims = [1, 1] + elif len(node_std_lims) != 2: + raise ValueError("node_std_lims must be a list of length 2.") + if edge_functions is None: + edge_functions = [lambda x: x] + if edge_weight_lims is None: + edge_weight_lims = [1, 1] + elif len(edge_weight_lims) != 2: + raise ValueError("edge_weight_lims must be a list of length 2.") + + # Create list of topologically sorted nodes + top_sort_idx = list(nx.topological_sort(G)) + + for node_idx in top_sort_idx: + # get all parents + parents = sorted(list(G.predecessors(node_idx))) + + # sample noise + mean = rng.uniform(low=node_mean_lims[0], high=node_mean_lims[1]) + std = rng.uniform(low=node_std_lims[0], high=node_std_lims[1]) + + # sample weight and edge function for each parent + node_function = dict() + for parent in parents: + weight = rng.uniform(low=edge_weight_lims[0], high=edge_weight_lims[1]) + func = rng.choice(edge_functions) + node_function.update({parent: {"weight": weight, "func": func}}) + + # set the node attribute "functions" to hold the weight and function wrt each parent + nx.set_node_attributes(G, {node_idx: node_function}, "parent_functions") + nx.set_node_attributes(G, {node_idx: {"mean": mean, "std": std}}, "gaussian_noise_function") + G.graph["linear_gaussian"] = True + return G diff --git a/pywhy_graphs/functional/tests/test_linear.py b/pywhy_graphs/functional/tests/test_linear.py new file mode 100644 index 000000000..e73a25f1c --- /dev/null +++ b/pywhy_graphs/functional/tests/test_linear.py @@ -0,0 +1,32 @@ +import networkx as nx +import pytest + +from pywhy_graphs.functional import make_graph_linear_gaussian +from pywhy_graphs.simulate import simulate_random_er_dag + + +def test_make_linear_gaussian_graph(): + G = simulate_random_er_dag(n_nodes=5, seed=12345, ensure_acyclic=True) + + G = make_graph_linear_gaussian(G, random_state=12345) + + assert all(key in nx.get_node_attributes(G, "parent_functions") for key in G.nodes) + assert all(key in nx.get_node_attributes(G, "gaussian_noise_function") for key in G.nodes) + + +def test_make_linear_gaussian_graph_errors(): + G = simulate_random_er_dag(n_nodes=2, seed=12345, ensure_acyclic=True) + + with pytest.raises(ValueError, match="must be a list of length 2."): + G = make_graph_linear_gaussian(G, node_mean_lims=[0], random_state=12345) + + with pytest.raises(ValueError, match="must be a list of length 2."): + G = make_graph_linear_gaussian(G, node_std_lims=[0], random_state=12345) + + with pytest.raises(ValueError, match="must be a list of length 2."): + G = make_graph_linear_gaussian(G, edge_weight_lims=[0], random_state=12345) + + with pytest.raises(ValueError, match="The input graph must be a DAG."): + G = make_graph_linear_gaussian( + nx.cycle_graph(4, create_using=nx.DiGraph), random_state=12345 + )