Skip to content

Commit

Permalink
Merge pull request #424 from WenjieDu/dev
Browse files Browse the repository at this point in the history
Enable tuning new added models
  • Loading branch information
WenjieDu authored May 27, 2024
2 parents f39bfe6 + 8e25601 commit 2668891
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 16 deletions.
44 changes: 42 additions & 2 deletions pypots/cli/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,35 @@
from .base import BaseCommand
from .utils import load_package_from_path
from ..classification import BRITS as BRITS_classification
from ..classification import Raindrop, GRUD
from ..classification import Raindrop
from ..clustering import CRLI, VaDER
from ..data.saving.h5 import load_dict_from_h5
from ..imputation import SAITS, Transformer, CSDI, USGAN, GPVAE, MRNN, BRITS, TimesNet
from ..imputation import (
SAITS,
FreTS,
Koopa,
iTransformer,
Crossformer,
TimesNet,
PatchTST,
ETSformer,
MICN,
DLinear,
SCINet,
NonstationaryTransformer,
FiLM,
Pyraformer,
Autoformer,
CSDI,
Informer,
USGAN,
StemGNN,
GPVAE,
MRNN,
BRITS,
GRUD,
Transformer,
)
from ..optim import Adam
from ..utils.logging import logger
from ..utils.random import set_random_seed
Expand All @@ -33,7 +58,22 @@
NN_MODELS = {
# imputation models
"pypots.imputation.SAITS": SAITS,
"pypots.imputation.iTransformer": iTransformer,
"pypots.imputation.Transformer": Transformer,
"pypots.imputation.FreTS": FreTS,
"pypots.imputation.Koopa": Koopa,
"pypots.imputation.Crossformer": Crossformer,
"pypots.imputation.PatchTST": PatchTST,
"pypots.imputation.ETSformer": ETSformer,
"pypots.imputation.MICN": MICN,
"pypots.imputation.DLinear": DLinear,
"pypots.imputation.SCINet": SCINet,
"pypots.imputation.NonstationaryTransformer": NonstationaryTransformer,
"pypots.imputation.FiLM": FiLM,
"pypots.imputation.Pyraformer": Pyraformer,
"pypots.imputation.Autoformer": Autoformer,
"pypots.imputation.Informer": Informer,
"pypots.imputation.StemGNN": StemGNN,
"pypots.imputation.TimesNet": TimesNet,
"pypots.imputation.CSDI": CSDI,
"pypots.imputation.USGAN": USGAN,
Expand Down
10 changes: 5 additions & 5 deletions pypots/imputation/film/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class FiLM(BaseNNImputer):
multiscale :
A list including the multiscale factors for the HiPPO projection layers.
ratio :
dropout :
The dropout ratio for the HiPPO projection layers.
It only works when mode_type == 1.
Expand Down Expand Up @@ -103,8 +103,8 @@ def __init__(
n_features: int,
window_size: list,
multiscale: list,
modes1: int,
ratio: float = 0.5,
modes1: int = 32,
dropout: float = 0.5,
mode_type: int = 0,
d_model: int = 128,
ORT_weight: float = 1,
Expand Down Expand Up @@ -135,7 +135,7 @@ def __init__(
self.window_size = window_size
self.multiscale = multiscale
self.modes1 = modes1
self.ratio = ratio
self.dropout = dropout
self.mode_type = mode_type
self.d_model = d_model
self.ORT_weight = ORT_weight
Expand All @@ -148,7 +148,7 @@ def __init__(
self.window_size,
self.multiscale,
self.modes1,
self.ratio,
self.dropout,
self.mode_type,
self.d_model,
self.ORT_weight,
Expand Down
8 changes: 4 additions & 4 deletions pypots/imputation/revinscinet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ def __init__(
n_groups: int,
n_decoder_layers: int,
d_hidden: int,
kernel_size: int,
dropout: float,
concat_len: int,
pos_enc: bool,
kernel_size: int = 3,
concat_len: int = 0,
dropout: float = 0.5,
pos_enc: bool = False,
ORT_weight: float = 1,
MIT_weight: float = 1,
batch_size: int = 32,
Expand Down
8 changes: 4 additions & 4 deletions pypots/imputation/scinet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ def __init__(
n_groups: int,
n_decoder_layers: int,
d_hidden: int,
kernel_size: int,
dropout: float,
concat_len: int,
pos_enc: bool,
kernel_size: int = 3,
concat_len: int = 0,
dropout: float = 0.5,
pos_enc: bool = False,
ORT_weight: float = 1,
MIT_weight: float = 1,
batch_size: int = 32,
Expand Down
2 changes: 1 addition & 1 deletion tests/imputation/film.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class TestFiLM(unittest.TestCase):
window_size=[2],
multiscale=[1, 2],
modes1=512,
ratio=0.5,
dropout=0.5,
d_model=512,
epochs=EPOCHS,
saving_path=saving_path,
Expand Down

0 comments on commit 2668891

Please sign in to comment.