From 08b89427e4fc5879901e0a69e1f777110b20ff3b Mon Sep 17 00:00:00 2001 From: guntherxing Date: Thu, 2 Nov 2023 10:53:26 +0800 Subject: [PATCH] [bug-fix] when use detailed freq like 5min TimeFeatureEmbedding will crush Signed-off-by: guntherxing --- layers/Embed.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/layers/Embed.py b/layers/Embed.py index 977e2556..ef5ed0c7 100644 --- a/layers/Embed.py +++ b/layers/Embed.py @@ -97,9 +97,20 @@ class TimeFeatureEmbedding(nn.Module): def __init__(self, d_model, embed_type='timeF', freq='h'): super(TimeFeatureEmbedding, self).__init__() - freq_map = {'h': 4, 't': 5, 's': 6, - 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3} - d_inp = freq_map[freq] + def freq_to_dim(freq): + while freq[0].isdigit(): + freq = freq[1:] + freq = freq.lower() + if freq == "min": + freq = 't' + elif freq == "A": + freq = 'y' + + freq_map = {'h': 4, 't': 5, 's': 6, + 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3} + return freq_map[freq] + + d_inp = freq_to_dim(freq) self.embed = nn.Linear(d_inp, d_model, bias=False) def forward(self, x):