-
Notifications
You must be signed in to change notification settings - Fork 0
/
episode_evals.py
113 lines (86 loc) · 4.19 KB
/
episode_evals.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from abc import ABC, abstractmethod
from typing import Tuple
from agents import CausalAgent
import networkx as nx
def directed_shd(predicted: nx.DiGraph, target: nx.DiGraph) -> int:
assert len(predicted.nodes) == len(target.nodes), 'Graphs need to have the same amount of nodes'
# this corresponds to the SHD (structural Hamming distance or SHD (Tsamardinos et al., 2006;
# with the difference that we consider undirected edges as bidirected edges and flip is considered
# as 2 errors instead of 1
differences = 0
for node in predicted.adj:
# check which edges are too much in the predicted graph
for parent in predicted.adj[node]:
if not parent.upper() in target.adj[node.upper()]:
differences += 1
# check which edges are missing in the predicted graph
for parent in target.adj[node.upper()]:
if not parent.lower() in predicted.adj[node]:
differences += 1
return differences
class EvalFunc(ABC):
"""Interface for evaluation function for the RL learning"""
effect_threshold: float
steps_this_episode: int
def __init__(self, agent: CausalAgent):
super(EvalFunc, self).__init__()
self.agent = agent
self.steps_this_episode = 0
@abstractmethod
def evaluate_step(self, action_successful: bool, allow_unsuccessful_actions: bool = True) -> Tuple[bool, float]:
"""
Calculates for each step whether the episode is done and what the reward for that step is
:param action_successful: whether the action which was taken was executed successfully
:param allow_unsuccessful_actions: Whether unsuccessful actions are allowed
:return: Whether the episode is done and the reward of the current step
"""
raise NotImplementedError
@abstractmethod
def _eval_model(self):
"""Defines how the model is evaluated"""
raise NotImplementedError
# ----------------------------------------------------------------------------
# Functions evaluating the model based on the structural difference to the target graph
class StructureEvalFunc(EvalFunc):
def __init__(self, agent: CausalAgent, graph: nx.DiGraph):
super(StructureEvalFunc, self).__init__(agent)
self.compare_graph = graph
def set_compare_graph(self, graph: nx.DiGraph):
self.compare_graph = graph
def _eval_model(self):
diff = directed_shd(self.agent.causal_model, self.compare_graph)
return -diff
@abstractmethod
def evaluate_step(self, action_successful: bool, allow_unsuccessful_actions: bool = True) -> Tuple[bool, float]:
"""
Calculates for each step whether the episode is done and what the reward for that step is
:param action_successful: whether the action which was taken was executed successfully
:param allow_unsuccessful_actions: Whether unsuccessful actions are allowed
:return: Whether the episode is done and the reward of the current step
"""
raise NotImplementedError
class FixedLengthStructEpisode(StructureEvalFunc):
def __init__(self, agent: CausalAgent, graph: nx.DiGraph, ep_length: int):
super(FixedLengthStructEpisode, self).__init__(agent, graph)
self.episode_length = ep_length
def evaluate_step(self, action_successful: bool, allow_unsuccessful_actions: bool = True) -> Tuple[bool, float]:
self.steps_this_episode += 1
done = False
# Evaluate when the episode length is reached
if self.steps_this_episode >= self.episode_length:
done = True
self.steps_this_episode = 0
reward = self._eval_model()
elif not action_successful and not allow_unsuccessful_actions: # illegal action was taken
reward = -1
else:
reward = 0
return done, reward
class NoEval(EvalFunc):
"""Does nothing. Used when applying the policy so there are no prints and rewards"""
def _eval_model(self):
pass
def __init__(self):
super(NoEval, self).__init__(None)
def evaluate_step(self, action_successful: bool, allow_unsuccessful_actions: bool = True) -> Tuple[bool, float]:
return False, 0.0