-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_compute_matching.py
98 lines (74 loc) · 4.13 KB
/
test_compute_matching.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch
from bipartite_match import CurrentElems, toy_e_weights_type
from compute_matching import compute_matching, jitter_matrix, weight_matrix
def unambiguous_matching():
    """Build a three-versus-one pool fixture together with its expected matching.

    The first list handed to ``CurrentElems`` holds three entries carrying the
    tensor values 2, 1 and 2; the second list holds one entry carrying the
    tensor value 0. Every entry has the form ``[tensor, 0, 5]``.

    NOTE(review): the meaning of the trailing ``0`` and ``5`` fields is defined
    by ``CurrentElems`` and is not visible from this file -- confirm there.

    Returns:
        tuple: ``(pool, edge_weights, expected)`` where ``edge_weights`` is the
        result of ``toy_e_weights_type()`` and ``expected`` is a 3x1 tensor
        whose only non-zero entry is the 1.0 in the second row.
    """
    first_side = [[torch.tensor(value), 0, 5] for value in (2, 1, 2)]
    second_side = [[torch.tensor(0), 0, 5]]
    pool = CurrentElems(first_side, second_side)
    edge_weights = toy_e_weights_type()
    expected = torch.tensor([[0.0], [1.0], [0.0]])
    return pool, edge_weights, expected
def test_compute_matching_noweights():
    """Check the unambiguous fixture against ``compute_matching`` with zero potentials.

    The potentials passed in are ``torch.zeros(5)``; the returned matching must
    be close (default ``torch.allclose`` tolerances) to the fixture's expected
    matching.
    """
    pool, weights_type, expected = unambiguous_matching()
    zero_potentials = torch.zeros(5)
    # compute_matching returns a pair; only the matching is checked here.
    matching, _ = compute_matching(pool, zero_potentials, weights_type)
    assert torch.allclose(matching, expected)
def ambiguous_matching():
    """Build a three-versus-two pool fixture with one acceptable matching.

    The first list handed to ``CurrentElems`` holds three entries carrying the
    tensor values 2, 1 and 2; the second list holds two identical entries
    carrying the tensor value 0. Every entry has the form ``[tensor, 0, 5]``.

    NOTE(review): the meaning of the trailing ``0`` and ``5`` fields is defined
    by ``CurrentElems`` and is not visible from this file -- confirm there.

    Returns:
        tuple: ``(pool, edge_weights, one_expected)`` where ``edge_weights`` is
        the result of ``toy_e_weights_type()`` and ``one_expected`` is a 3x2
        tensor whose only non-zero entry is the 1.0 at row 1, column 0.
    """
    first_side = [[torch.tensor(value), 0, 5] for value in (2, 1, 2)]
    second_side = [[torch.tensor(0), 0, 5] for _ in range(2)]
    pool = CurrentElems(first_side, second_side)
    edge_weights = toy_e_weights_type()
    one_expected = torch.tensor(
        [
            [0.0, 0.0],
            [1.0, 0.0],
            [0.0, 0.0],
        ]
    )
    return pool, edge_weights, one_expected
def test_compute_ambiguous():
    """Check that the ambiguous fixture resolves to the single expected matching.

    Ties ought to be broken deterministically, by ordering (the edge weights
    should be jittered to achieve this). The comparison uses ``atol=1e-4``.
    """
    pool, weights_type, one_expected = ambiguous_matching()
    zero_potentials = torch.zeros(5)
    # compute_matching returns a pair; only the matching is checked here.
    matching, _ = compute_matching(pool, zero_potentials, weights_type)
    assert torch.allclose(matching, one_expected, atol=1e-4)
def test_weight_matrix():
    """Check ``weight_matrix`` on a three-versus-one pool with the toy weights.

    The pool's ``lhs`` and ``rhs`` attributes are passed to ``weight_matrix``
    along with ``toy_e_weights_type()``. The expected result is a 3x1 tensor
    holding -100.0, 3.0 and -100.0 for entries built from the tensor values
    2, 1 and 2 respectively, against a single entry built from tensor value 0.
    """
    first_side = [[torch.tensor(value), 0, 5] for value in (2, 1, 2)]
    second_side = [[torch.tensor(0), 0, 5]]
    pool = CurrentElems(first_side, second_side)
    edge_weights = toy_e_weights_type()
    expected_weights = torch.tensor([[-100.0], [3.0], [-100.0]])
    actual_weights = weight_matrix(pool.lhs, pool.rhs, edge_weights)
    assert torch.allclose(actual_weights, expected_weights)
def test_potentials():
    """Check the matching when the potential at index 1 is +1.0.

    The pool has two entries (tensor values 1 and 2) in its first list and one
    entry (tensor value 2) in its second list. With potentials
    ``[0, 1, 0, 0, 0]`` the expected matching is the 2x1 tensor ``[[0], [1]]``,
    compared with ``atol=1e-6``.
    """
    first_side = [[torch.tensor(value), 0, 5] for value in (1, 2)]
    second_side = [[torch.tensor(2), 0, 5]]
    pool = CurrentElems(first_side, second_side)
    edge_weights = toy_e_weights_type()
    potentials = torch.tensor([0.0, 1.0, 0.0, 0.0, 0.0])
    expected = torch.tensor([[0.0], [1.0]])
    # compute_matching returns a pair; only the matching is checked here.
    matching, _ = compute_matching(pool, potentials, edge_weights)
    assert torch.allclose(matching, expected, atol=1e-6)
def test_opposite_potentials():
    """Check the matching when the potential at index 1 is -1.0.

    Uses the same two-versus-one pool as ``test_potentials`` (tensor values 1
    and 2 against tensor value 2). With potentials ``[0, -1, 0, 0, 0]`` the
    expected matching is the 2x1 tensor ``[[1], [0]]``, compared with
    ``atol=1e-6``.
    """
    first_side = [[torch.tensor(value), 0, 5] for value in (1, 2)]
    second_side = [[torch.tensor(2), 0, 5]]
    pool = CurrentElems(first_side, second_side)
    edge_weights = toy_e_weights_type()
    potentials = torch.tensor([0.0, -1.0, 0.0, 0.0, 0.0])
    expected = torch.tensor([[1.0], [0.0]])
    # compute_matching returns a pair; only the matching is checked here.
    matching, _ = compute_matching(pool, potentials, edge_weights)
    assert torch.allclose(matching, expected, atol=1e-6)
def test_zero_potentials():
    """Check the matching when every potential is zero.

    Uses a two-versus-one pool (tensor values 1 and 2 against tensor value 2).
    With an all-zero potential vector of length 5 the expected matching is the
    2x1 tensor ``[[1], [0]]``, compared with ``atol=1e-6``.
    """
    first_side = [[torch.tensor(value), 0, 5] for value in (1, 2)]
    second_side = [[torch.tensor(2), 0, 5]]
    pool = CurrentElems(first_side, second_side)
    edge_weights = toy_e_weights_type()
    potentials = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0])
    expected = torch.tensor([[1.0], [0.0]])
    # compute_matching returns a pair; only the matching is checked here.
    matching, _ = compute_matching(pool, potentials, edge_weights)
    assert torch.allclose(matching, expected, atol=1e-6)
def dont_test_tiebreak():
    """Known-failing tie-break scenario that reveals a fractional matching.

    The pool has three identical entries (tensor value 1) in its first list and
    two identical entries (tensor value 1) in its second list. The desired
    result is the integral 3x2 matching ``[[1, 0], [0, 1], [0, 0]]``, compared
    with ``atol=1e-6``; per the original author's note the assertion currently
    fails because the computed matching is fractional.

    NOTE(review): the name lacks the ``test_`` prefix, presumably so that the
    test runner does not collect it while the failure is outstanding -- rename
    it to ``test_tiebreak`` once tie-breaking is fixed.
    """
    lhs_entries = [[torch.tensor(1), 0, 5] for _ in range(3)]
    rhs_entries = [[torch.tensor(1), 0, 5] for _ in range(2)]
    currpool = CurrentElems(lhs_entries, rhs_entries)
    e_weights_type = toy_e_weights_type()
    desired_match = torch.tensor(
        [
            [1.0, 0.0],
            [0.0, 1.0],
            [0.0, 0.0],
        ]
    )
    # Only the matching is checked; the returned edge weights are not used.
    result_match, _ = compute_matching(currpool, torch.zeros(5), e_weights_type)
    assert torch.allclose(result_match, desired_match, atol=1e-6)