-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain_model.py
58 lines (48 loc) · 2.02 KB
/
main_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
from torch import nn, optim
import torch.nn.functional as F
from block.embed_block import embed
from block.TVA_block import TVA_block_att
from block.decoder_block import TVADE_block
from block.revin import RevIN
class DSFormer(nn.Module):
    """Dual-stream forecasting model (DSFormer).

    Pipeline visible in this class: RevIN normalisation -> subsequence
    embedding producing two streams -> one shared TVA encoder applied to
    each stream -> elementwise sum + LayerNorm -> TVA decoder -> 1x1
    Conv1d projecting the feature length to the forecast horizon ->
    RevIN denormalisation.
    """

    def __init__(self, Input_len, out_len, num_id, num_layer, dropout, muti_head, num_samp, IF_node):
        """
        Input_len: history length
        out_len:   future (forecast) length
        num_id:    number of variables
        num_layer: number of layers, 1 or 2
        dropout:   dropout rate, 0.15 to 0.3
        muti_head: number of multi-head attention heads, 1 to 4
        num_samp:  number of multi-head subsequences, 2 or 3
        IF_node:   whether to use node embedding, True or False
        """
        super().__init__()
        # Node embedding doubles the effective per-subsequence length.
        # NOTE: keep the original evaluation order (2 * Input_len) // num_samp;
        # it differs from 2 * (Input_len // num_samp) when num_samp does not
        # divide Input_len evenly.
        self.inputlen = 2 * Input_len // num_samp if IF_node else Input_len // num_samp

        # --- normalisation, embedding, encoder ---
        self.RevIN = RevIN(num_id)
        self.embed_layer = embed(Input_len, num_id, num_samp, IF_node)
        self.encoder = TVA_block_att(self.inputlen, num_id, num_layer, dropout, muti_head, num_samp)
        self.laynorm = nn.LayerNorm([self.inputlen])

        # --- decoder and output projection to the forecast horizon ---
        self.decoder = TVADE_block(self.inputlen, num_id, dropout, muti_head)
        self.output = nn.Conv1d(in_channels=self.inputlen, out_channels=out_len, kernel_size=1)

    def forward(self, x):
        """Map a history window to a forecast.

        Input  [B, H, N]: B batch size, H history length, N number of variables.
        Output [B, L, N]: B batch size, L future length, N number of variables.
        """
        # Normalise, then swap the last two axes so variables lead time.
        normed = self.RevIN(x, 'norm').transpose(-2, -1)
        stream_a, stream_b = self.embed_layer(normed)

        # The same encoder instance (shared weights) processes both streams;
        # their outputs are fused by summation and normalised.
        fused = self.laynorm(self.encoder(stream_a) + self.encoder(stream_b))

        # Decode, project feature length -> forecast horizon, denormalise.
        decoded = self.decoder(fused)
        projected = self.output(decoded.transpose(-2, -1))
        return self.RevIN(projected, 'denorm')