A new model called LMSAutoTSF #622

Open · wants to merge 2 commits into base: main
5 changes: 3 additions & 2 deletions exp/exp_basic.py
@@ -2,7 +2,7 @@
import torch
from models import Autoformer, Transformer, TimesNet, Nonstationary_Transformer, DLinear, FEDformer, \
Informer, LightTS, Reformer, ETSformer, Pyraformer, PatchTST, MICN, Crossformer, FiLM, iTransformer, \
-    Koopa, TiDE, FreTS, TimeMixer, TSMixer, SegRNN, MambaSimple, TemporalFusionTransformer, SCINet, PAttn, TimeXer
+    Koopa, TiDE, FreTS, TimeMixer, TSMixer, SegRNN, MambaSimple, TemporalFusionTransformer, SCINet, PAttn, TimeXer, LMSAutoTSF


class Exp_Basic(object):
@@ -35,7 +35,8 @@ def __init__(self, args):
'TemporalFusionTransformer': TemporalFusionTransformer,
"SCINet": SCINet,
'PAttn': PAttn,
-            'TimeXer': TimeXer
+            'TimeXer': TimeXer,
+            'LMSAutoTSF': LMSAutoTSF
}
if args.model == 'Mamba':
print('Please make sure you have successfully installed mamba_ssm')
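(For context, not part of the diff: the experiment classes build the network from this registry, roughly `self.model_dict[self.args.model].Model(self.args).float()`, so the new dictionary entry is all that is needed for `--model LMSAutoTSF` to resolve.)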
162 changes: 162 additions & 0 deletions models/LMSAutoTSF.py
@@ -0,0 +1,162 @@

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

from layers.StandardNorm import Normalize


def compute_lagged_difference(x, lag=1):
lagged_x = torch.roll(x, shifts=lag, dims=1)
diff_x = x - lagged_x
diff_x[:, :lag, :] = x[:, :lag, :]
return diff_x
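
# Illustrative (not part of the model): for one channel [1, 2, 4, 7] with
# lag=1, torch.roll gives the wrapped series [7, 1, 2, 4], so x - lagged_x is
# [-6, 1, 2, 3]; the first `lag` entries are then reset to the original
# values, since the wrap-around difference is meaningless: [1, 1, 2, 3].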


class Encoder(nn.Module):
def __init__(self, configs, seq_len, pred_len):
super(Encoder, self).__init__()
self.seq_len = seq_len
self.pred_len = pred_len
self.feature_dim = configs.enc_in
self.channel_independence = configs.channel_independence

self.linear_final = nn.Linear(self.seq_len, self.pred_len)

self.temporal = nn.Sequential(
nn.Linear(self.seq_len, configs.d_model),
nn.ReLU(),
nn.Linear(configs.d_model, self.seq_len),
nn.Dropout(configs.dropout)
)

if not self.channel_independence:
self.channel = nn.Sequential(
nn.Linear(self.feature_dim, configs.d_model),
nn.ReLU(),
nn.Linear(configs.d_model, self.feature_dim),
nn.Dropout(configs.dropout)
)

    def forward(self, x_enc):
        # x_enc: [B, T, C]. The temporal MLP mixes information along the time
        # axis; its output is gated elementwise by the lagged difference of
        # the input, weighting each step by how much the series is changing.
        x_temp = self.temporal(x_enc.permute(0, 2, 1)).permute(0, 2, 1)
        x_temp = torch.multiply(x_temp, compute_lagged_difference(x_enc))
        x = x_enc + x_temp

        # Optional mixing across channels when channel independence is off
        if not self.channel_independence:
            x = x + self.channel(x_temp)

        # Map the sequence length to the forecast horizon: [B, T, C] -> [B, pred_len, C]
        return self.linear_final(x.permute(0, 2, 1)).permute(0, 2, 1)


class Model(nn.Module):
def __init__(self, configs):
super(Model, self).__init__()

self.task_name = configs.task_name
self.seq_len = configs.seq_len
self.label_len = configs.label_len
self.pred_len = configs.pred_len
self.feature_dim = configs.enc_in
self.d_model = configs.d_model
self.down_sampling_layers = 3
self.down_sampling_window = 2

sequence_list = [1]
current = 2
for _ in range(1, self.down_sampling_layers+1):
sequence_list.append(current)
current *= 2

num_scales = len(sequence_list)

self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

self.cutoff_frequencies = nn.Parameter(torch.ones(num_scales, self.feature_dim, device=self.device) * torch.tensor(0.2))
self.stepness = nn.Parameter(torch.ones(num_scales, self.feature_dim, device=self.device) * torch.tensor(10))

self.encoder_Seasonal = torch.nn.ModuleList([Encoder(configs, self.seq_len//i, self.pred_len) for i in sequence_list])
self.encoder_Trend = torch.nn.ModuleList([Encoder(configs, self.seq_len//i, self.pred_len) for i in sequence_list])

self.normalize_layer = Normalize(configs.enc_in, affine=True, non_norm=True if configs.use_norm == 0 else False)
self.projection = nn.Linear(self.pred_len * num_scales, self.pred_len)
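
        # sequence_list is [1, 2, 4, 8] for down_sampling_layers=3, i.e. four
        # scales whose encoders see lengths seq_len, seq_len/2, seq_len/4 and
        # seq_len/8; the projection fuses the concatenated per-scale forecasts
        # (num_scales * pred_len) back down to a single pred_len horizon.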


def __multi_scale_process_inputs(self, x_enc, x_mark_enc):

down_pool = torch.nn.AvgPool1d(self.down_sampling_window)
# B,T,C -> B,C,T
x_enc = x_enc.permute(0, 2, 1)

x_enc_ori = x_enc
x_mark_enc_mark_ori = x_mark_enc

x_enc_sampling_list = []
x_mark_sampling_list = []
x_enc_sampling_list.append(x_enc.permute(0, 2, 1))
x_mark_sampling_list.append(x_mark_enc)

for i in range(self.down_sampling_layers):
x_enc_sampling = down_pool(x_enc_ori)

x_enc_sampling_list.append(x_enc_sampling.permute(0, 2, 1))
x_enc_ori = x_enc_sampling

if x_mark_enc is not None:
x_mark_sampling_list.append(x_mark_enc_mark_ori[:, ::self.down_sampling_window, :])
x_mark_enc_mark_ori = x_mark_enc_mark_ori[:, ::self.down_sampling_window, :]

x_enc = x_enc_sampling_list
x_mark_enc = x_mark_sampling_list if x_mark_enc is not None else None

return x_enc, x_mark_enc
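
    # With the scripts' seq_len=96, the list above holds tensors of shape
    # [B, T_i, C] with T_i = 96, 48, 24, 12; the time marks are strided by the
    # same factor so they stay aligned with each downsampled series.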


def low_pass_filter(self, x_freq, seq_len, cutoff_frequency, stepness):
freqs = torch.fft.fftfreq(seq_len, d=1.0).to(x_freq.device)
mask = torch.sigmoid(-(freqs.unsqueeze(-1) - cutoff_frequency) * stepness) # Apply different cutoff for each feature
mask = mask.to(x_freq.device)
x_freq_filtered = x_freq * mask
return x_freq_filtered

def high_pass_filter(self, x_freq, seq_len, cutoff_frequency, stepness):
freqs = torch.fft.fftfreq(seq_len, d=1.0).to(x_freq.device)
mask = torch.sigmoid((freqs.unsqueeze(-1) - cutoff_frequency) * stepness) # Apply different cutoff for each feature
mask = mask.to(x_freq.device)
x_freq_filtered = x_freq * mask
return x_freq_filtered
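
    # Since sigmoid(z) + sigmoid(-z) = 1, the two masks are exactly
    # complementary: for a real input, x_low + x_high reconstructs x up to
    # floating-point error. Both the cutoff frequency and the slope
    # (`stepness`) are learnable, per scale and per feature.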


def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):

x_enc = self.normalize_layer(x_enc, 'norm')

output_list = []
# ******************* SCALED INPUTS *******************************
x_enc_list, x_mark_enc_list = self.__multi_scale_process_inputs(x_enc, x_mark_enc)
        for i, x in enumerate(x_enc_list):
seq_len = x.shape[1]
# Frequency domain processing
x_freq = torch.fft.fft(x, dim=1)

x_freq_low = self.low_pass_filter(x_freq, seq_len, self.cutoff_frequencies[i], self.stepness[i])
x_freq_high = self.high_pass_filter(x_freq, seq_len, self.cutoff_frequencies[i], self.stepness[i])

x_low = torch.fft.ifft(x_freq_low, dim=1).real
x_high = torch.fft.ifft(x_freq_high, dim=1).real

seasonal_output = self.encoder_Seasonal[i](x_high)
trend_output = self.encoder_Trend[i](x_low)
output = seasonal_output + trend_output

output_list.append(output)


output = torch.cat(output_list, dim=1)
output = self.projection(output.permute(0,2,1)).permute(0,2,1)

output = self.normalize_layer(output, 'denorm')
return output
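
A quick way to sanity-check the new module is a standalone forward pass. The sketch below is mine, not part of the PR; it assumes it is run from the repository root (so `layers.StandardNorm` resolves) and fills in only the config fields the model actually reads:

    from types import SimpleNamespace
    import torch
    from models.LMSAutoTSF import Model

    # Minimal config mirroring the fields used by Model/Encoder above.
    configs = SimpleNamespace(
        task_name='long_term_forecast', seq_len=96, label_len=48, pred_len=96,
        enc_in=7, d_model=512, dropout=0.1, channel_independence=0, use_norm=1,
    )

    # The model allocates its filter parameters on CUDA when available, so
    # keep the module and the inputs on the same device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Model(configs).to(device)

    x_enc = torch.randn(2, configs.seq_len, configs.enc_in, device=device)  # [B, T, C]
    x_mark_enc = torch.randn(2, configs.seq_len, 4, device=device)          # calendar features
    out = model(x_enc, x_mark_enc, None, None)
    print(out.shape)  # expected: torch.Size([2, 96, 7])

Note that `x_mark_enc` is downsampled for alignment but never consumed afterwards, so its feature count is arbitrary here.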
89 changes: 89 additions & 0 deletions scripts/long_term_forecast/ECL_script/LMSAutoTSF.sh
@@ -0,0 +1,89 @@
model_name=LMSAutoTSF
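
# Four long-term-forecast runs on Electricity (321 variates); only model_id
# and pred_len (96 / 192 / 336 / 720) change between the invocations below.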

python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/electricity/ \
--data_path electricity.csv \
--model_id ECL_96_96 \
--model $model_name \
--channel_independence 0 \
--data custom \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len 96 \
--e_layers 2 \
--d_layers 1 \
--factor 3 \
--enc_in 321 \
--dec_in 321 \
--c_out 321 \
--des 'Exp' \
--itr 1

python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/electricity/ \
--data_path electricity.csv \
--model_id ECL_96_192 \
--model $model_name \
--channel_independence 0 \
--data custom \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len 192 \
--e_layers 2 \
--d_layers 1 \
--factor 3 \
--enc_in 321 \
--dec_in 321 \
--c_out 321 \
--des 'Exp' \
--itr 1

python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/electricity/ \
--data_path electricity.csv \
--model_id ECL_96_336 \
--model $model_name \
--channel_independence 0 \
--data custom \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len 336 \
--e_layers 2 \
--d_layers 1 \
--factor 3 \
--enc_in 321 \
--dec_in 321 \
--c_out 321 \
--des 'Exp' \
--itr 1

python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/electricity/ \
--data_path electricity.csv \
--model_id ECL_96_720 \
--model $model_name \
--channel_independence 0 \
--data custom \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len 720 \
--e_layers 2 \
--d_layers 1 \
--factor 3 \
--enc_in 321 \
--dec_in 321 \
--c_out 321 \
--des 'Exp' \
--itr 1
90 changes: 90 additions & 0 deletions scripts/long_term_forecast/ETT_script/LMSAutoTSF_ETTh1.sh
@@ -0,0 +1,90 @@

model_name=LMSAutoTSF
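
# Four long-term-forecast runs on ETTh1 (7 variates); only model_id and
# pred_len (96 / 192 / 336 / 720) change between the invocations below.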

python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/ETT-small/ \
--data_path ETTh1.csv \
--model_id ETTh1_96_96 \
--model $model_name \
--channel_independence 0 \
--data ETTh1 \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len 96 \
--e_layers 2 \
--d_layers 1 \
--factor 3 \
--enc_in 7 \
--dec_in 7 \
--c_out 7 \
--des 'Exp' \
--itr 1

python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/ETT-small/ \
--data_path ETTh1.csv \
--model_id ETTh1_96_192 \
--model $model_name \
--channel_independence 0 \
--data ETTh1 \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len 192 \
--e_layers 2 \
--d_layers 1 \
--factor 3 \
--enc_in 7 \
--dec_in 7 \
--c_out 7 \
--des 'Exp' \
--itr 1

python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/ETT-small/ \
--data_path ETTh1.csv \
--model_id ETTh1_96_336 \
--model $model_name \
--channel_independence 0 \
--data ETTh1 \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len 336 \
--e_layers 2 \
--d_layers 1 \
--factor 3 \
--enc_in 7 \
--dec_in 7 \
--c_out 7 \
--des 'Exp' \
--itr 1

python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/ETT-small/ \
--data_path ETTh1.csv \
--model_id ETTh1_96_720 \
--model $model_name \
--channel_independence 0 \
--data ETTh1 \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len 720 \
--e_layers 2 \
--d_layers 1 \
--factor 3 \
--enc_in 7 \
--dec_in 7 \
--c_out 7 \
--des 'Exp' \
--itr 1