-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtests.py
97 lines (71 loc) · 2.91 KB
/
tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from itertools import product
import pytest
import torch
import nilm_disaggregation_nets as nets
# ### NEURAL NILM ###
# Parameter grids for the NeuralNILM shape tests that follow; the tests'
# decorators combine them with itertools.product, one case per combination.
in_channels = list(range(1, 6))  # 1, 2, ..., 5
out_channels = list(range(1, 6))  # 1, 2, ..., 5
sequence_length = list(range(20, 100, 10))  # 20, 30, ..., 90
batch_size = list(range(10, 60, 10))  # 10, 20, ..., 50
@pytest.mark.parametrize(
    "in_channels,out_channels,sequence_length,batch_size",
    list(product(in_channels, out_channels, sequence_length, batch_size)),
)
def test_NeuralNilmDAE(in_channels, out_channels, sequence_length, batch_size):
    """Check NeuralNilmDAE maps (batch, in_ch, seq) to (batch, out_ch, seq).

    The input is filled with ``torch.rand`` (finite values in [0, 1)) instead
    of ``torch.empty``, whose uninitialized memory may hold NaN/Inf garbage.
    """
    net = nets.NeuralNilmDAE(
        sequence_length=sequence_length,
        in_channels=in_channels,
        out_channels=out_channels,
    )
    x = torch.rand(batch_size, in_channels, sequence_length)
    # Only the output shape is checked, so skip building an autograd graph.
    with torch.no_grad():
        y = net(x)
    assert tuple(y.shape) == (batch_size, out_channels, sequence_length)
@pytest.mark.parametrize(
    "in_channels,out_channels,sequence_length,batch_size",
    list(product(in_channels, out_channels, sequence_length, batch_size)),
)
def test_NeuralNilmBiLSTM(in_channels, out_channels, sequence_length, batch_size):
    """Check NeuralNilmBiLSTM maps (batch, in_ch, seq) to (batch, out_ch, seq).

    The input is filled with ``torch.rand`` (finite values in [0, 1)) instead
    of ``torch.empty``, whose uninitialized memory may hold NaN/Inf garbage.
    """
    net = nets.NeuralNilmBiLSTM(
        sequence_length=sequence_length,
        in_channels=in_channels,
        out_channels=out_channels,
    )
    x = torch.rand(batch_size, in_channels, sequence_length)
    # Only the output shape is checked, so skip building an autograd graph.
    with torch.no_grad():
        y = net(x)
    assert tuple(y.shape) == (batch_size, out_channels, sequence_length)
# ### SEQ2SEQ SEQ2POINT ###
# Parameter grids for the sequence-to-sequence / sequence-to-point shape tests.
# These deliberately rebind the module-level names used by the NeuralNILM
# section: each @pytest.mark.parametrize argument list is evaluated when its
# decorator runs, so tests defined earlier keep the grids they already captured.
in_channels = list(range(1, 6))  # 1, 2, ..., 5
out_channels = list(range(1, 6))  # 1, 2, ..., 5
input_length = list(range(30, 100, 10))  # 30, 40, ..., 90; the minimum is 29
batch_size = list(range(10, 60, 10))  # 10, 20, ..., 50
@pytest.mark.parametrize(
    "in_channels,input_length,batch_size",
    list(product(in_channels, input_length, batch_size)),
)
def test_SeqToBase(in_channels, input_length, batch_size):
    """Check SeqToBase maps (batch, in_ch, input_length) to (batch, 1024).

    The input is filled with ``torch.rand`` (finite values in [0, 1)) instead
    of ``torch.empty``, whose uninitialized memory may hold NaN/Inf garbage.
    """
    net = nets.SeqToBase(input_length=input_length, in_channels=in_channels)
    x = torch.rand(batch_size, in_channels, input_length)
    # Only the output shape is checked, so skip building an autograd graph.
    with torch.no_grad():
        y = net(x)
    assert tuple(y.shape) == (batch_size, 1024)
@pytest.mark.parametrize(
    "in_channels,out_channels,input_length,batch_size",
    list(product(in_channels, out_channels, input_length, batch_size)),
)
def test_SeqToSeq(in_channels, out_channels, input_length, batch_size):
    """Check SeqToSeq maps (batch, in_ch, length) to (batch, out_ch, length).

    The input is filled with ``torch.rand`` (finite values in [0, 1)) instead
    of ``torch.empty``, whose uninitialized memory may hold NaN/Inf garbage.
    """
    net = nets.SeqToSeq(
        input_length=input_length,
        in_channels=in_channels,
        out_channels=out_channels,
    )
    x = torch.rand(batch_size, in_channels, input_length)
    # Only the output shape is checked, so skip building an autograd graph.
    with torch.no_grad():
        y = net(x)
    assert tuple(y.shape) == (batch_size, out_channels, input_length)
@pytest.mark.parametrize(
    "in_channels,out_channels,input_length,batch_size",
    list(product(in_channels, out_channels, input_length, batch_size)),
)
def test_SeqToPoint(in_channels, out_channels, input_length, batch_size):
    """Check SeqToPoint maps (batch, in_ch, length) to (batch, out_ch, 1).

    The input is filled with ``torch.rand`` (finite values in [0, 1)) instead
    of ``torch.empty``, whose uninitialized memory may hold NaN/Inf garbage.
    """
    net = nets.SeqToPoint(
        input_length=input_length,
        in_channels=in_channels,
        out_channels=out_channels,
    )
    x = torch.rand(batch_size, in_channels, input_length)
    # Only the output shape is checked, so skip building an autograd graph.
    with torch.no_grad():
        y = net(x)
    assert tuple(y.shape) == (batch_size, out_channels, 1)