[ENH] Add the ability to check the validity of a PAG (#100)
* Added a _proper_pag function
* Added some test for legal_pag
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Update examples/intro/checking_validity_of_a_pag.py

---------

Signed-off-by: Aryan Roy <[email protected]>
Co-authored-by: Adam Li <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Jul 15, 2024
1 parent 71a18e7 commit 31ad037
Showing 10 changed files with 480 additions and 21 deletions.
21 changes: 14 additions & 7 deletions .circleci/config.yml
@@ -7,7 +7,7 @@ commands:
- run:
name: Check-skip
command: |
if [ ! -d "sktree" ]; then
if [ ! -d "pywhy_graphs" ]; then
echo "Build was not run due to skip, exiting job ${CIRCLE_JOB} for PR ${CIRCLE_PULL_REQUEST}."
circleci-agent step halt;
fi
@@ -54,7 +54,7 @@ commands:
echo ${CI_PULL_REQUEST//*pull\//} | tee merge.txt
if [[ $(cat merge.txt) != "" ]]; then
echo "Merging $(cat merge.txt)";
git remote add upstream https://github.com/neurodata/scikit-tree.git;
git remote add upstream https://github.com/py-why/pywhy-graphs.git;
git pull --ff-only upstream "refs/pull/$(cat merge.txt)/merge";
git fetch upstream main;
fi
@@ -64,7 +64,7 @@ jobs:
docker:
# CircleCI maintains a library of pre-built images
# documented at https://circleci.com/doc/2.0/circleci-images/
- image: cimg/python:3.9
- image: cimg/python:3.11
steps:
- checkout
- check-skip
@@ -96,18 +96,25 @@ jobs:
name: Setup torch for pgmpy
command: |
sudo apt-get install nvidia-cuda-toolkit nvidia-cuda-toolkit-gcc
- run:
name: Install dodiscover
command: |
git clone https://github.com/py-why/dodiscover.git
cd dodiscover
python -m pip install .
- run:
name: Check installation
command: |
python -c "import pywhy_graphs;"
python -c "import numpy; numpy.show_config()"
python -c "import dodiscover;"
LIBGL_DEBUG=verbose python -c "import matplotlib.pyplot as plt; plt.figure()"
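The "Check installation" step above runs one `python -c "import ..."` command per package. As a rough sketch only, the same idea could be collected into a single helper script; the names `REQUIRED`, `missing_modules` and `main` are hypothetical and are not part of this commit or of the CI configuration.

```python
import importlib.util

# Top-level packages that the CI step above imports (assumed list).
REQUIRED = ("pywhy_graphs", "numpy", "dodiscover", "matplotlib")


def missing_modules(names=REQUIRED):
    """Return the top-level module names that cannot be located for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def main():
    """Report missing packages and return a shell-style exit code."""
    missing = missing_modules()
    if missing:
        print("Missing packages: " + ", ".join(missing))
        return 1
    print("All required packages can be located.")
    return 0
```

A CI step could call `sys.exit(main())`. Note that `find_spec` only locates a module without executing it, so unlike the real `import` commands in the step above it would not catch errors raised at import time.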
# dowhy currently requires an older version of numpy
- run:
name: Temporary Hack for numpy
command: |
python -m pip install numpy==1.22.0
# - run:
# name: Temporary Hack for numpy
# command: |
# python -m pip install numpy==1.22.0

- run:
name: Build documentation
17 changes: 16 additions & 1 deletion .github/workflows/circle_artifacts.yml
@@ -1,14 +1,29 @@
name: CircleCI artifacts redirector
on: [status]

# Restrict the permissions granted to the use of secrets.GITHUB_TOKEN in this
# github actions workflow:
# https://docs.github.com/en/actions/security-guides/automatic-token-authentication
permissions: read-all

jobs:
circleci_artifacts_redirector_job:
if: "${{ startsWith(github.event.context, 'ci/circleci: build_doc') }}"
runs-on: ubuntu-20.04
if: "github.repository == 'py-why/pywhy-graphs' && github.event.context == 'ci/circleci: build_doc'"
permissions:
statuses: write
name: Run CircleCI artifacts redirector
steps:
- name: GitHub Action step
uses: larsoner/circleci-artifacts-redirector-action@master
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
api-token: ${{ secrets.CIRCLECI_TOKEN }}
artifact-path: 0/dev/index.html
circleci-jobs: build_doc
job-title: Check the rendered docs here!

- name: Check the URL
if: github.event.status != 'pending'
run: |
curl --fail ${{ steps.step1.outputs.url }} | grep $GITHUB_SHA
29 changes: 19 additions & 10 deletions .github/workflows/main.yml
@@ -117,11 +117,26 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
architecture: "x64"

- name: Setup torch for pgmpy
if: "matrix.os == 'ubuntu'"
shell: bash
run: |
sudo apt-get update
sudo apt-get install nvidia-cuda-toolkit nvidia-cuda-toolkit-gcc
- name: Install packages via pip
run: |
pip install --upgrade pip
pip install numpy scipy networkx statsmodels
pip install .[test]
python -m pip install --upgrade pip
python -m pip install numpy scipy networkx statsmodels
python -m pip install .[test]
- name: Install DoDiscover (main)
run: |
git clone https://github.com/py-why/dodiscover.git
cd dodiscover
python -m pip install .
- name: Install Networkx (main)
if: "matrix.networkx == 'main'"
run: |
@@ -131,16 +146,10 @@ jobs:
pip install .[default]
# pip install --progress-bar off git+https://github.com/networkx/networkx
- name: Setup torch for pgmpy
if: "matrix.os == 'ubuntu'"
shell: bash
run: |
sudo apt-get update
sudo apt-get install nvidia-cuda-toolkit nvidia-cuda-toolkit-gcc
- name: Run pytest # headless via Xvfb on linux
run: |
pytest --cov pywhy_graphs ./pywhy_graphs
- name: Upload coverage stats to codecov
if: ${{ matrix.os == 'ubuntu' && matrix.python-version == '3.11' && matrix.networkx == 'stable' }}
uses: codecov/codecov-action@v4
4 changes: 4 additions & 0 deletions doc/reference/algorithms/index.rst
@@ -25,6 +25,10 @@ Core Algorithms
possible_descendants
discriminating_path
is_definite_noncollider
valid_pag
mag_to_pag
pag_to_mag
check_pag_definition

.. currentmodule:: pywhy_graphs.networkx

1 change: 1 addition & 0 deletions doc/whats_new/v0.2.rst
@@ -32,6 +32,7 @@ Changelog
- |Feature| Implement a suite of functions for finding and checking semi-directed paths on a mixed-edge graph, by `Adam Li`_ (:pr:`101`)
- |Feature| Implement functions for converting between a DAG and PDAG and CPDAG for generating consistent extensions of a CPDAG for example. These functions are :func:`pywhy_graphs.algorithms.pdag_to_cpdag`, :func:`pywhy_graphs.algorithms.pdag_to_dag` and :func:`pywhy_graphs.algorithms.dag_to_cpdag`, by `Adam Li`_ (:pr:`102`)
- |API| Remove poetry based setup, by `Adam Li`_ (:pr:`110`)
- |Feature| Implement and test function to validate PAG, by `Aryan Roy`_ (:pr:`100`)

Code and Documentation Contributors
-----------------------------------
97 changes: 97 additions & 0 deletions examples/intro/checking_validity_of_a_pag.py
@@ -0,0 +1,97 @@
"""
===========================
On PAGs and their validity
===========================
A PAG or a Partial Ancestral Graph is a type of mixed edge
graph that can represent, in a single graph, the causal relationship
between several nodes as defined by an equivalence class of MAGs.
PAGs account for possible unobserved confounding and selection bias
in the underlying equivalence class of SCMs.
Another way to understand this is that PAGs encode conditional independence
constraints stemming from Causal Graphs. Since these constraints do not lead to a
unique graph, a PAG, in essence, represents a class of graphs that encode
the same conditional independence constraints.
PAGs model this relationship by displaying all edge marks (tail and arrowhead) that are shared
by all members of the equivalence class and displaying circle endpoints for those marks
that are not shared. That is, a circle endpoint (``*-o``) can stand for either an arrowhead
(``*->``) or a tail (``*-``) endpoint in the causal graphs within the equivalence class.
More details on PAGs can be found at :footcite:`Zhang2008`.
"""

import pywhy_graphs
from pywhy_graphs.viz import draw
from pywhy_graphs import PAG

try:
    from dodiscover import FCI, make_context
    from dodiscover.ci import Oracle
    from dodiscover.constraint.utils import dummy_sample
except ImportError as e:
    # Chain the original error so the underlying import failure stays visible.
    raise ImportError(
        "The 'dodiscover' package is required to convert a MAG to a PAG."
    ) from e


# %%
# PAGs in pywhy-graphs
# ---------------------------
# Constructing a PAG in pywhy-graphs is straightforward because
# the library provides a dedicated class for this purpose.
# True to the definition of PAGs, the class can contain
# directed edges, bidirected edges, undirected edges and
# circle edges. To illustrate this, we construct an example PAG
# as described in :footcite:`Zhang2008`, Figure 4:

pag = PAG()
pag.add_edge("I", "S", pag.directed_edge_name)
pag.add_edge("G", "S", pag.directed_edge_name)
pag.add_edge("G", "L", pag.directed_edge_name)
pag.add_edge("S", "L", pag.directed_edge_name)
pag.add_edge("PSH", "S", pag.directed_edge_name)
pag.add_edge("S", "PSH", pag.circle_edge_name)
pag.add_edge("S", "G", pag.circle_edge_name)
pag.add_edge("S", "I", pag.circle_edge_name)


# Finally, the graph looks like this:
dot_graph = draw(pag)
dot_graph.render(outfile="valid_pag.png", view=True)


# %%
# Validity of a PAG
# ---------------------------
# For a PAG to be valid, it must represent a valid
# equivalence class of MAGs. This can be verified by
# turning the PAG into a MAG and then checking the
# validity of the MAG.
# Theorem 2 in :footcite:`Zhang2008` provides a method for checking the validity of a PAG.
# To check whether the constructed PAG is valid in pywhy-graphs, we can simply do:


# returns True
print(pywhy_graphs.valid_pag(pag))

# %%
# To see that the check can also fail, we can change a single mark in the graph so that
# it is no longer a valid PAG. A circle endpoint stands for a mark that differs across the
# MAGs this PAG represents, so removing one asserts a definite mark in its place.
# In this specific case, by removing the circle endpoint ``S *-o I``, the remaining edge
# ``I -> S`` says that ``I`` directly causes ``S``. However, there is no way of determining
# this using the FCI logical rules.
# One would not be able to determine whether the adjacency is due to a direct
# causal relationship (directed edge), a confounded relationship (bidirected edge), or an
# inducing path relationship. As such, the resulting graph is no longer a valid PAG.

pag.remove_edge("S", "I", pag.circle_edge_name)

# returns False
print(pywhy_graphs.valid_pag(pag))

# %%
# References
# ----------
# .. footbibliography::
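
The example file above describes a circle endpoint as a mark that may be either a tail or an arrowhead in the underlying graphs. The following is a minimal, library-free sketch of that idea for the single edge `I o-> S`. The dictionary encoding and the helpers `render` and `resolutions` are assumptions made for illustration; they are not the pywhy-graphs API and do not reproduce what `valid_pag` does internally.

```python
from itertools import product

ARROW, CIRCLE, TAIL = "arrow", "circle", "tail"

# marks[(u, v)] is the endpoint mark at v on the edge between u and v.
# Hypothetical encoding of the edge I o-> S: arrowhead at S, circle at I.
marks = {("I", "S"): ARROW, ("S", "I"): CIRCLE}


def render(marks, u, v):
    """Render the u-v edge as text, for example 'I o-> S'."""
    left = {ARROW: "<", CIRCLE: "o", TAIL: ""}[marks[(v, u)]]
    right = {ARROW: ">", CIRCLE: "o", TAIL: ""}[marks[(u, v)]]
    return f"{u} {left}-{right} {v}"


def resolutions(marks):
    """Yield every way of replacing each circle with a tail or an arrowhead."""
    circles = [pair for pair, mark in marks.items() if mark == CIRCLE]
    for choice in product((TAIL, ARROW), repeat=len(circles)):
        resolved = dict(marks)
        resolved.update(zip(circles, choice))
        yield resolved


print(render(marks, "I", "S"))  # I o-> S
for resolved in resolutions(marks):
    # Prints "I -> S", then "I <-> S".
    print(render(resolved, "I", "S"))
```

Enumerating marks this way only lists candidate orientations. Whether a given combination is a valid MAG, and whether a graph with a circle removed is still a valid PAG, is a separate question that the sketch does not answer.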
4 changes: 2 additions & 2 deletions examples/intro/intro_causal_graphs.py
@@ -136,10 +136,10 @@ def clone(self):
# Using the graph, we can explore d-separation statements, which by the Markov
# condition, imply conditional independences.
# For example, 'z' is d-separated from 'x' because of the collider at 'y'
print(f"'z' is d-separated from 'x': {nx.is_d_separator(G, {'z'}, {'x'}, set())}")
print(f"'z' is d-separated from 'x': {nx.d_separated(G, {'z'}, {'x'}, set())}")

# Conditioning on the collider, opens up the path
print(f"'z' is d-separated from 'x' given 'y': {nx.is_d_separator(G, {'z'}, {'x'}, {'y'})}")
print(f"'z' is d-separated from 'x' given 'y': {nx.d_separated(G, {'z'}, {'x'}, {'y'})}")
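
The two statements above rely on a d-separation check. As a self-contained illustration of what such a check computes, here is a sketch of the standard moralization test on a hypothetical collider graph `x -> y <- z`. The graph `G` from the example file is not shown in this hunk, so the graph used here is an assumption, and the helper `is_d_separated` is written for this sketch rather than taken from NetworkX.

```python
from collections import deque


def ancestors_of(parents, nodes):
    """Return `nodes` together with all of their ancestors."""
    seen = set(nodes)
    stack = list(nodes)
    while stack:
        node = stack.pop()
        for parent in parents.get(node, ()):
            if parent not in seen:
                seen.add(parent)
                stack.append(parent)
    return seen


def is_d_separated(parents, xs, ys, zs):
    """Check whether `zs` d-separates `xs` from `ys` in a DAG.

    `parents` maps each node to the set of its parents.
    """
    # 1. Restrict to the ancestral subgraph of xs, ys and zs.
    keep = ancestors_of(parents, set(xs) | set(ys) | set(zs))
    # 2. Moralize: connect each node to its parents and marry co-parents.
    neighbors = {node: set() for node in keep}
    for child in keep:
        pas = [p for p in parents.get(child, ()) if p in keep]
        for p in pas:
            neighbors[child].add(p)
            neighbors[p].add(child)
        for i, p in enumerate(pas):
            for q in pas[i + 1:]:
                neighbors[p].add(q)
                neighbors[q].add(p)
    # 3. Remove zs and test undirected reachability from xs to ys.
    blocked = set(zs)
    queue = deque(x for x in xs if x not in blocked)
    seen = set(queue)
    while queue:
        node = queue.popleft()
        if node in ys:
            return False
        for nb in neighbors[node]:
            if nb not in blocked and nb not in seen:
                seen.add(nb)
                queue.append(nb)
    return True


# Hypothetical collider: x -> y <- z
collider = {"x": set(), "z": set(), "y": {"x", "z"}}
print(is_d_separated(collider, {"z"}, {"x"}, set()))  # True
print(is_d_separated(collider, {"z"}, {"x"}, {"y"}))  # False
```

With an empty conditioning set the collider `y` is not in the ancestral subgraph, so `x` and `z` stay disconnected. Conditioning on `y` brings it in, and moralization then links its two parents, which is the "opening up" of the path described in the comment above.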

# %%
# Acyclic Directed Mixed Graphs (ADMG)
1 change: 1 addition & 0 deletions pyproject.toml
@@ -52,6 +52,7 @@ doc = [
'graphviz',
'pygraphviz',
'pgmpy',
'dowhy',
]
style = [
"pre-commit",