Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Var.plot_vars and Model.plot_vars #230

Merged
merged 11 commits into from
Feb 5, 2025
138 changes: 132 additions & 6 deletions liesel/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from collections.abc import Iterable
from copy import deepcopy
from types import MappingProxyType
from typing import IO, Any, TypeVar
from typing import IO, Any, Literal, TypeVar

import dill
import jax
Expand Down Expand Up @@ -282,11 +282,13 @@ def _all_nodes_and_vars(self) -> tuple[list[Node], list[Var]]:
return all_nodes, all_vars

@staticmethod
def _do_set_missing_names(nodes_or_vars: Iterable[NV], prefix: str) -> None:
def _do_set_missing_names(nodes_or_vars: Iterable[NV], prefix: str) -> list[str]:
"""Sets the missing names for the given nodes or variables."""
other = [nv.name for nv in nodes_or_vars if nv.name]
counter = -1

automatically_set_names = []

for nv in nodes_or_vars:
if not nv.name:
name = f"{prefix}{(counter := counter + 1)}"
Expand All @@ -296,13 +298,16 @@ def _do_set_missing_names(nodes_or_vars: Iterable[NV], prefix: str) -> None:

nv.name = name
other.append(name)
automatically_set_names.append(name)

return automatically_set_names

def _set_missing_names(self) -> GraphBuilder:
def _set_missing_names(self) -> dict[str, list[str]]:
"""Sets the missing node and variable names."""
nodes, _vars = self._all_nodes_and_vars()
self._do_set_missing_names(_vars, prefix="v")
self._do_set_missing_names(nodes, prefix="n")
return self
auto_var_names = self._do_set_missing_names(_vars, prefix="v")
auto_node_names = self._do_set_missing_names(nodes, prefix="n")
return {"vars": auto_var_names, "nodes": auto_node_names}

def add(
self, *args: Node | Var | GraphBuilder, to_float32: bool | None = None
Expand Down Expand Up @@ -1206,6 +1211,29 @@ def copy_nodes_and_vars(self) -> tuple[dict[str, Node], dict[str, Var]]:

return nodes, _vars

def node_parental_subgraph(self, *of: Node) -> nx.DiGraph:
"""
Returns a subgraph that consists of the input nodes and their parent nodes.
"""
nodes_to_include = set()
for node in of:
nodes_to_include.update(nx.ancestors(self.node_graph, node))
nodes_to_include.add(node)
subgraph = self.node_graph.subgraph(nodes_to_include)
return subgraph

def var_parental_subgraph(self, *of: Var) -> nx.DiGraph:
"""
Returns a subgraph that consists of the input variables and their parent
variables.
"""
nodes_to_include = set()
for node in of:
nodes_to_include.update(nx.ancestors(self.var_graph, node))
nodes_to_include.add(node)
subgraph = self.var_graph.subgraph(nodes_to_include)
return subgraph

@property
def log_lik(self) -> Array:
"""
Expand Down Expand Up @@ -1407,6 +1435,104 @@ def __repr__(self) -> str:
brackets = f"({len(self._nodes)} nodes, {len(self._vars)} vars)"
return type(self).__name__ + brackets

def plot_vars(
self,
show: bool = True,
save_path: str | None | IO = None,
width: int = 14,
height: int = 10,
prog: Literal[
"dot", "circo", "fdp", "neato", "osage", "patchwork", "sfdp", "twopi"
] = "dot",
):
"""
Plots the variables of this model.

Wraps :func:`~.viz.plot_vars`.

Parameters
----------
show
Whether to show the plot in a new window.
save_path
Path to save the plot. If not provided, the plot will not be saved.
width
Width of the plot in inches.
height
Height of the plot in inches.
prog
Layout parameter. Available layouts: circo, dot (the default), fdp, neato, \
osage, patchwork, sfdp, twopi.

See Also
--------
.Var.plot_vars : Plots the variables of the Liesel sub-model that terminates in
this variable.
.Var.plot_nodes : Plots the nodes of the Liesel sub-model that terminates in
this variable.
.Model.plot_vars : Plots the variables of a Liesel model.
.Model.plot_nodes : Plots the nodes of a Liesel model.
.viz.plot_vars : Plots the variables of a Liesel model.
.viz.plot_nodes : Plots the nodes of a Liesel model.
"""
return plot_vars(
self,
show=show,
save_path=save_path,
width=width,
height=height,
prog=prog,
)

def plot_nodes(
self,
show: bool = True,
save_path: str | None | IO = None,
width: int = 14,
height: int = 10,
prog: Literal[
"dot", "circo", "fdp", "neato", "osage", "patchwork", "sfdp", "twopi"
] = "dot",
):
"""
Plots the nodes of this model.

Wraps :func:`~.viz.plot_nodes`.

Parameters
----------
show
Whether to show the plot in a new window.
save_path
Path to save the plot. If not provided, the plot will not be saved.
width
Width of the plot in inches.
height
Height of the plot in inches.
prog
Layout parameter. Available layouts: circo, dot (the default), fdp, neato, \
osage, patchwork, sfdp, twopi.

See Also
--------
.Var.plot_vars : Plots the variables of the Liesel sub-model that terminates in
this variable.
.Var.plot_nodes : Plots the nodes of the Liesel sub-model that terminates in
this variable.
.Model.plot_vars : Plots the variables of a Liesel model.
.Model.plot_nodes : Plots the nodes of a Liesel model.
.viz.plot_vars : Plots the variables of a Liesel model.
.viz.plot_nodes : Plots the nodes of a Liesel model.
"""
return plot_nodes(
self,
show=show,
save_path=save_path,
width=width,
height=height,
prog=prog,
)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Save and load models ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
198 changes: 197 additions & 1 deletion liesel/model/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,24 @@
from functools import wraps
from itertools import chain
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, NamedTuple, TypeGuard, TypeVar, Union
from typing import (
IO,
TYPE_CHECKING,
Any,
Literal,
NamedTuple,
TypeGuard,
TypeVar,
Union,
)

import tensorflow_probability.substrates.jax.bijectors as jb
import tensorflow_probability.substrates.jax.distributions as jd
import tensorflow_probability.substrates.numpy.bijectors as nb
import tensorflow_probability.substrates.numpy.distributions as nd

from ..distributions.nodist import NoDistribution
from .viz import plot_nodes, plot_vars

if TYPE_CHECKING:
from .model import Model
Expand Down Expand Up @@ -1949,6 +1959,192 @@ def weak(self) -> bool:
def __repr__(self) -> str:
return f'{type(self).__name__}(name="{self.name}")'

def _plot(
self, which: Literal["vars", "nodes"] = "vars", verbose: bool = False, **kwargs
) -> None:

if self.model is not None:
match which:
case "vars":
subgraph = self.model.var_parental_subgraph(self)
return plot_vars(subgraph, **kwargs)
case "nodes":
self_nodes = [self.value_node, self.dist_node, self.var_value_node]
filtered_nodes = [nd for nd in self_nodes if nd is not None]
jobrachem marked this conversation as resolved.
Show resolved Hide resolved
subgraph = self.model.node_parental_subgraph(*filtered_nodes)
return plot_nodes(subgraph, **kwargs)

from liesel.model import GraphBuilder

gb = GraphBuilder().add(self)
nodes, _vars = gb._all_nodes_and_vars()

automatically_set_names = gb._set_missing_names()
var_names = automatically_set_names["vars"]
node_names = automatically_set_names["nodes"]
if var_names:
if verbose:
names_ = f"The automatically assigned names are: {var_names}. "
else:
names_ = ""
logger.info(
f"Unnamed variables were temporarily named for plotting. {names_}"
"The names are reset"
" after plotting."
)
if node_names:
if verbose:
names_ = f"The automatically assigned names are: {node_names}. "
else:
names_ = ""
logger.info(
f"Unnamed nodes were temporarily named for plotting. {names_}"
"The names are reset"
" after plotting."
)

model = gb.build_model()

match which:
case "vars":
subgraph = model.var_parental_subgraph(self)
plot_vars(subgraph, **kwargs)
case "nodes":
self_nodes = [self.value_node, self.dist_node, self.var_value_node]
filtered_nodes = [nd for nd in self_nodes if nd is not None]
subgraph = model.node_parental_subgraph(*filtered_nodes)
plot_nodes(subgraph, **kwargs)

model.pop_nodes_and_vars()

vars_dict = {var_.name: var_ for var_ in _vars}
nodes_dict = {node.name: node for node in nodes}

for name in var_names:
vars_dict[name].name = ""

for name in node_names:
nodes_dict[name].name = ""

gb.nodes.clear()
gb.vars.clear()

def plot_vars(
self,
show: bool = True,
save_path: str | None | IO = None,
width: int = 14,
height: int = 10,
prog: Literal[
"dot", "circo", "fdp", "neato", "osage", "patchwork", "sfdp", "twopi"
] = "dot",
verbose: bool = False,
) -> None:
"""
Plots the variables of the Liesel sub-model that terminates in this variable.

Wraps :func:`~.viz.plot_vars`.

Parameters
----------
verbose
If ``True``, logs a message if unnamed variables or nodes are temporarily \
named for plotting.
show
Whether to show the plot in a new window.
save_path
Path to save the plot. If not provided, the plot will not be saved.
width
Width of the plot in inches.
height
Height of the plot in inches.
prog
Layout parameter. Available layouts: circo, dot (the default), fdp, neato, \
osage, patchwork, sfdp, twopi.
verbose
If ``True``, the message that will be logged if unnamed nodes are \
automatically named for plotting contains a list of the automatically \
assigned names.

See Also
--------
.Var.plot_vars : Plots the variables of the Liesel sub-model that terminates in
this variable.
.Var.plot_nodes : Plots the nodes of the Liesel sub-model that terminates in
this variable.
.Model.plot_vars : Plots the variables of a Liesel model.
.Model.plot_nodes : Plots the nodes of a Liesel model.
.viz.plot_vars : Plots the variables of a Liesel model.
.viz.plot_nodes : Plots the nodes of a Liesel model.
"""
return self._plot(
which="vars",
verbose=verbose,
show=show,
save_path=save_path,
width=width,
height=height,
prog=prog,
)

def plot_nodes(
self,
show: bool = True,
save_path: str | None | IO = None,
width: int = 14,
height: int = 10,
prog: Literal[
"dot", "circo", "fdp", "neato", "osage", "patchwork", "sfdp", "twopi"
] = "dot",
verbose: bool = False,
) -> None:
"""
Plots the nodes of the Liesel sub-model that terminates in this variable.

Wraps :func:`~.viz.plot_nodes`.

Parameters
----------
verbose
If ``True``, logs a message if unnamed variables or nodes are temporarily \
named for plotting.
show
Whether to show the plot in a new window.
save_path
Path to save the plot. If not provided, the plot will not be saved.
width
Width of the plot in inches.
height
Height of the plot in inches.
prog
Layout parameter. Available layouts: circo, dot (the default), fdp, neato, \
osage, patchwork, sfdp, twopi.
verbose
If ``True``, the message that will be logged if unnamed nodes are \
automatically named for plotting contains a list of the automatically \
assigned names.

See Also
--------
.Var.plot_vars : Plots the variables of the Liesel sub-model that terminates in
this variable.
.Var.plot_nodes : Plots the nodes of the Liesel sub-model that terminates in
this variable.
.Model.plot_vars : Plots the variables of a Liesel model.
.Model.plot_nodes : Plots the nodes of a Liesel model.
.viz.plot_vars : Plots the variables of a Liesel model.
.viz.plot_nodes : Plots the nodes of a Liesel model.
"""
return self._plot(
which="nodes",
verbose=verbose,
show=show,
save_path=save_path,
width=width,
height=height,
prog=prog,
)


def _transform_var_with_bijector_instance(var: Var, bijector_inst: jb.Bijector) -> Var:
if var.dist_node is None: # type: ignore
Expand Down
Loading