-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpc_algorithm.py
170 lines (151 loc) · 7.65 KB
/
pc_algorithm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import numpy as np
import networkx as nx
from itertools import combinations, permutations
from gsq.ci_tests import ci_test_bin, ci_test_dis
from gsq.gsq_testdata import bin_data, dis_data
def estimate_skeleton(data, independence_test, significance_level):
    """Estimate the undirected skeleton used by the PC algorithm.

    Starts from the complete undirected graph over the columns of ``data``.
    For conditioning-set sizes 0, 1, 2, ... it looks, for every adjacent pair
    ``node - neighbor``, for a subset of ``node``'s other neighbours that makes
    the pair conditionally independent. Edges found removable at a given size
    are deleted only after the whole level has been scanned (the PC-stable
    scheme), so results within a level do not depend on node ordering.

    Args:
        data: 2-D array; ``data.shape[1]`` is the number of variables/nodes
            (presumably rows are samples -- confirm with callers).
        independence_test: callable ``(data, x, y, conditioning_set) -> p_value``.
        significance_level: an edge is removed when some conditioning set gives
            a p-value strictly greater than this threshold.

    Returns:
        ``(graph, separation_sets_lst)``: ``graph`` is the undirected
        ``networkx`` graph left after pruning, and ``separation_sets_lst[i][j]``
        (== ``separation_sets_lst[j][i]``) holds the separating nodes found for
        ``i`` and ``j`` -- an empty set if they were never separated. When both
        orientations ``(i, j)`` and ``(j, i)`` find a separating set at the same
        level, the two sets are unioned.
    """
    n_nodes = data.shape[1]
    graph = nx.complete_graph(n_nodes)
    separation_sets_lst = [[set() for _ in range(n_nodes)] for _ in range(n_nodes)]

    conditioning_size = 0
    while True:
        # True while at least one adjacent pair still has enough other
        # neighbours to draw a conditioning set of the current size; this is
        # the PC stopping criterion.
        continue_loop = False
        # Edges to drop at the end of this level (fresh for every level).
        remove_edges_lst = []
        for node in graph.nodes:
            neighbors = list(graph.neighbors(node))
            for neighbor in neighbors:
                other_neighbors = [n for n in neighbors if n != neighbor]
                if len(other_neighbors) < conditioning_size:
                    continue
                continue_loop = True
                for subset in combinations(other_neighbors, conditioning_size):
                    separating_set = set(subset)
                    p_value = independence_test(data, node, neighbor, separating_set)
                    if p_value > significance_level:
                        remove_edges_lst.append((node, neighbor))
                        separation_sets_lst[node][neighbor] |= separating_set
                        separation_sets_lst[neighbor][node] |= separating_set
                        # One separating set suffices; testing further subsets
                        # for this ordered pair would be wasted CI tests.
                        break
        # Deferred removal: the graph is not modified while a level is scanned.
        graph.remove_edges_from(remove_edges_lst)
        if not continue_loop:
            break
        conditioning_size += 1
    return graph, separation_sets_lst
def has_both_edges(dag, node_1, node_2):
    """Return True when ``dag`` holds the arc in each direction between the nodes.

    In the CPDAG encoding used in this file, that pair of arcs represents
    the undirected edge ``node_1 - node_2``.
    """
    if not dag.has_edge(node_1, node_2):
        return False
    return dag.has_edge(node_2, node_1)
def has_any_edge(dag, node_1, node_2):
    """Return True when ``dag`` holds an arc between the nodes in either direction.

    Used here as the adjacency test: directed and undirected edges both count.
    """
    if dag.has_edge(node_1, node_2):
        return True
    return dag.has_edge(node_2, node_1)
def has_one_edge(dag, node_1, node_2):
    """Return True when exactly one of the two possible arcs between the nodes exists.

    In the CPDAG encoding used in this file, that means the pair is joined by a
    directed edge (either way round) rather than an undirected edge or no edge.
    """
    forward = dag.has_edge(node_1, node_2)
    backward = dag.has_edge(node_2, node_1)
    # Exclusive or of the two arc tests.
    return forward != backward
def has_no_edge(dag, node_1, node_2):
    """Return True when ``dag`` holds no arc between the nodes in either direction."""
    return not (dag.has_edge(node_1, node_2) or dag.has_edge(node_2, node_1))
def estimate_cpdag(skeleton_graph, separation_sets_lst):
    """Orient the edges of a PC skeleton, producing a CPDAG.

    The result is the directed copy produced by ``skeleton_graph.to_directed()``,
    with orientations applied by deleting arcs. A directed edge ``a -> b`` is
    stored as the single arc ``(a, b)``; an undirected edge ``a - b`` keeps
    both arcs ``(a, b)`` and ``(b, a)``.

    Orientation happens in two phases:

    1. Unshielded colliders: for every nonadjacent pair ``node_1, node_2`` and
       every common neighbour ``k`` not in their separation set, orient
       ``node_1 -> k <- node_2``.
    2. Meek rules 1-3, applied to every ordered pair of nodes and repeated
       until a full sweep orients nothing new. Rule 4 is not applied; per the
       original author it is not needed when the PC algorithm is used to
       estimate a DAG.

    Args:
        skeleton_graph: undirected skeleton, e.g. as returned by
            ``estimate_skeleton``.
        separation_sets_lst: nested list where ``separation_sets_lst[i][j]`` is
            the separation set of nodes ``i`` and ``j``. Nodes are used directly
            as list indices, so this assumes they are the integers ``0..N-1``.

    Returns:
        The partially oriented directed graph.
    """
    dag = skeleton_graph.to_directed()
    _orient_v_structures(dag, skeleton_graph, separation_sets_lst)

    rules = (_meek_rule_1, _meek_rule_2, _meek_rule_3)
    while True:
        n_edges_before = dag.number_of_edges()
        for node_1, node_2 in permutations(skeleton_graph.nodes(), 2):
            for rule in rules:
                # Only an undirected edge node_1 - node_2 can be oriented; once a
                # rule fires the edge is directed and later rules skip it.
                if has_both_edges(dag, node_1, node_2) and rule(dag, node_1, node_2):
                    # Make node_1 - node_2 into node_1 -> node_2.
                    dag.remove_edge(node_2, node_1)
        # The rules only ever delete the reverse arc of an undirected edge, so
        # an unchanged edge count means this sweep made no progress. (Comparing
        # graphs with nx.is_isomorphic is wrong here: an oriented graph can be
        # isomorphic to its predecessor, which would stop propagation early.)
        if dag.number_of_edges() == n_edges_before:
            break
    return dag


def _orient_v_structures(dag, skeleton_graph, separation_sets_lst):
    """Orient node_1 -> k <- node_2 in ``dag`` for each unshielded collider k."""
    for node_1, node_2 in combinations(skeleton_graph.nodes(), 2):
        node_1_successors = set(dag.successors(node_1))
        if node_2 in node_1_successors:
            continue  # adjacent pair: no unshielded triple
        node_2_successors = set(dag.successors(node_2))
        if node_1 in node_2_successors:
            continue
        for common in node_1_successors & node_2_successors:
            if common in separation_sets_lst[node_1][node_2]:
                continue  # common neighbour separates them: not a collider
            if dag.has_edge(common, node_1):
                dag.remove_edge(common, node_1)
            if dag.has_edge(common, node_2):
                dag.remove_edge(common, node_2)


def _meek_rule_1(dag, node_1, node_2):
    """Rule 1: true if some k -> node_1 exists with k and node_2 nonadjacent."""
    for k in dag.predecessors(node_1):
        if dag.has_edge(node_1, k):
            continue  # k - node_1 is undirected, not an arrow into node_1
        if not has_any_edge(dag, k, node_2):
            return True
    return False


def _meek_rule_2(dag, node_1, node_2):
    """Rule 2: true if there is a directed chain node_1 -> k -> node_2."""
    directed_out_of_1 = {
        k for k in dag.successors(node_1) if not dag.has_edge(k, node_1)
    }
    directed_into_2 = {
        k for k in dag.predecessors(node_2) if not dag.has_edge(node_2, k)
    }
    return bool(directed_out_of_1 & directed_into_2)


def _meek_rule_3(dag, node_1, node_2):
    """Rule 3: true if node_1 - k -> node_2 and node_1 - l -> node_2 with k, l nonadjacent."""
    undirected_neighbors = {
        k for k in dag.successors(node_1) if dag.has_edge(k, node_1)
    }
    for k, l in combinations(undirected_neighbors, 2):
        if has_any_edge(dag, k, l):
            continue
        k_into_2 = dag.has_edge(k, node_2) and not dag.has_edge(node_2, k)
        l_into_2 = dag.has_edge(l, node_2) and not dag.has_edge(node_2, l)
        if k_into_2 and l_into_2:
            return True
    return False
# Self-check: rebuild the CPDAG from the bundled binary test data and compare it
# with the expected structure.
# NOTE(review): this runs whenever the module is imported; consider moving it
# under `if __name__ == "__main__":` so importing the estimators has no side
# effects.
data = np.array(bin_data).reshape((5000, 5))
graph, separation_sets_lst = estimate_skeleton(
    data=data, independence_test=ci_test_bin, significance_level=0.01
)
graph = estimate_cpdag(skeleton_graph=graph, separation_sets_lst=separation_sets_lst)
graph_test = nx.DiGraph()
graph_test.add_nodes_from([0, 1, 2, 3, 4])
graph_test.add_edges_from([(0, 1), (2, 3), (3, 2), (3, 1), (2, 4), (4, 2), (4, 1)])
# Raise explicitly instead of using `assert`, which `python -O` strips, so the
# check always runs and reports what was actually estimated.
# NOTE(review): is_isomorphic ignores node labels; comparing edge sets directly
# would be a stricter check.
if not nx.is_isomorphic(graph, graph_test):
    raise AssertionError(
        "estimated CPDAG edges {} are not isomorphic to expected edges {}".format(
            sorted(graph.edges()), sorted(graph_test.edges())
        )
    )