-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathVRM_loss_functions.py
56 lines (41 loc) · 1.37 KB
/
VRM_loss_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
def calculate_VRM_loss_on_sequence(deepDFA, sym_seq, rew_seq):
    """Stub: VRM loss over one (symbol, reward) sequence — not implemented.

    Always returns 0, so it currently contributes a no-op loss term.
    Parameters are presumably a DeepDFA model plus aligned symbol/reward
    sequences — TODO confirm against callers once implemented.
    """
    return 0
def calculate_VRM_loss_on():
    """Stub: VRM loss entry point — not implemented; always returns 0."""
    return 0
def sat_current_output(r_pred, r_target):
    """Satisfaction of the current reward prediction against the target.

    For each row i of r_pred, picks the value at the target reward index,
    i.e. r_pred[i, r_target[i]], and returns those values as a 1-D tensor.
    """
    idx = r_target.unsqueeze(1)
    return r_pred.gather(1, idx).squeeze(1)
def sat_next_transition_batch(s_pred_batch, r_target_batch, deep_dfa):
    """Per-example satisfaction of the next-transition constraint.

    Calls sat_next_transition on each (state prediction, target reward)
    pair in the batch and collects the scalar results into a 1-D tensor.

    Args:
        s_pred_batch: batched predicted state distributions, first dim = batch.
        r_target_batch: batched target reward indices, one per example.
        deep_dfa: model exposing step()/numb_of_actions/numb_of_rewards
            (used by sat_next_transition).

    Returns:
        1-D float tensor of length batch_size with one satisfaction degree
        per example.
    """
    batch_size = s_pred_batch.size(0)
    # Allocate on the same device as the input so GPU batches don't mix
    # devices when the per-example results are written in.
    sat_batch = torch.zeros(batch_size, device=s_pred_batch.device)
    for i in range(batch_size):
        sat_batch[i] = sat_next_transition(s_pred_batch[i], r_target_batch[i], deep_dfa)
    return sat_batch
def sat_next_transition(s_pred, r_target, deep_dfa):
    # Satisfaction degree of the "next transition" constraint for one example.
    # r_target == 0 is treated as unconstrained: trivially satisfied.
    if r_target == 0:
        return torch.ones((1))
    # One one-hot row per possible action; pair each with a copy of the
    # predicted state and take a single DFA step for all actions at once.
    next_action = torch.eye(deep_dfa.numb_of_actions)
    s_pred = s_pred.repeat(deep_dfa.numb_of_actions,1)
    _, next_rew = deep_dfa.step(s_pred, next_action, 1.0)
    if r_target == deep_dfa.numb_of_rewards -1:
        # Highest reward index: require it under ALL actions (soft forall, p=3).
        sat = forall(next_rew[:,deep_dfa.numb_of_rewards - 1], 3)
    else:
        # Otherwise require SOME action that yields the target reward
        # (soft exists, p=3).
        # NOTE(review): indexing column r_target - 1 looks off by one relative
        # to the forall branch, which indexes numb_of_rewards - 1 directly —
        # confirm the intended reward-column layout.
        sat = exists(next_rew[:, r_target - 1], 3)
    return sat
def exists(tensor, p):
    """Soft existential quantifier over the truth values in *tensor*.

    Computes the generalized (power-p) mean, which approaches max() as p
    grows — so one high value is enough for a high satisfaction score.
    """
    return tensor.pow(p).mean().pow(1.0 / p)
def forall(tensor, p):
    """Soft universal quantifier over the truth values in *tensor*.

    Dual of the soft exists: 1 - pmean(1 - tensor), which approaches min()
    as p grows — so every value must be high for a high satisfaction score.
    """
    violation = (1 - tensor).pow(p).mean().pow(1.0 / p)
    return 1 - violation
def conjunction(a,b):
    """Product t-norm: fuzzy logical AND of two satisfaction degrees."""
    return a*b