Skip to content

Commit

Permalink
Update deeptcn model
Browse files Browse the repository at this point in the history
  • Loading branch information
PvtKaefsky committed May 17, 2024
1 parent 53dbd5b commit 0ad95e1
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions fedot_ind/core/models/nn/network_impl/deep_tcn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Optional, Union
from typing import Optional
import math

import pandas as pd
Expand All @@ -13,19 +13,19 @@
from fedot.core.operations.operation_parameters import OperationParameters
from fedot.core.repository.dataset_types import DataTypesEnum
from fedot.core.repository.tasks import Task, TaskTypesEnum, TsForecastingParams
from torch import nn, optim
from torch import nn, optim, Tensor
from torch.optim import lr_scheduler

from fedot_ind.core.architecture.abstraction.decorators import convert_to_4d_torch_array
from fedot_ind.core.architecture.abstraction.decorators import convert_inputdata_to_torch_time_series_dataset
from fedot_ind.core.architecture.preprocessing.data_convertor import DataConverter
from fedot_ind.core.architecture.settings.computational import backend_methods as np
from fedot_ind.core.architecture.settings.computational import default_device
from fedot_ind.core.models.nn.network_impl.base_nn_model import BaseNeuralModel
from fedot_ind.core.models.nn.network_modules.layers.backbone import _PatchTST_backbone
from fedot_ind.core.models.nn.network_modules.layers.special import adjust_learning_rate, EarlyStopping
from fedot_ind.core.operation.transformation.data.hankel import HankelMatrix
from fedot_ind.core.operation.transformation.window_selector import WindowSizeSelector
from fedot_ind.core.repository.constanst_repository import EXPONENTIAL_WEIGHTED_LOSS
from fedot_ind.core.repository.constanst_repository import RMSE

warnings.filterwarnings("ignore", category=UserWarning)

Expand Down Expand Up @@ -168,10 +168,10 @@ def __init__(self, params: Optional[OperationParameters] = {}):

self.kernel_size = params.get('kernel_size', 3)
self.num_filters = params.get('num_filters', 3)
self.num_layers = params.get('num_filters', None)
self.num_layers = params.get('num_layers', None)
self.dilation_base = params.get('dilation_base', 2)
self.dropout = params.get('dropout', 0.2)
self.weight_norm = params.get('dropout', False)
self.weight_norm = params.get('weight_norm', False)

def _init_model(self, ts):
model = _TCNModule(input_size=ts.features.shape[0],
Expand All @@ -185,9 +185,7 @@ def _init_model(self, ts):
dropout=self.dropout,
weight_norm=self.weight_norm)
optimizer = optim.Adam(model.parameters(), lr=self.learning_rate)
patch_pred_len = round(self.horizon / 4)
loss_fn = EXPONENTIAL_WEIGHTED_LOSS(
time_steps=patch_pred_len, tolerance=0.3)
loss_fn = RMSE()
return model, loss_fn, optimizer

# @property
Expand Down Expand Up @@ -438,3 +436,10 @@ def _predict_loop(self, model,
test_loader = torch.utils.data.DataLoader(data.TensorDataset(features, target),
batch_size=self.batch_size, shuffle=False)
return self._predict(model, test_loader)

@convert_to_4d_torch_array
def _predict_model(self, x_test):
self.model.eval()
x_test = Tensor(x_test).to(default_device())
pred = self.model(x_test)
return self._convert_predict(pred)

0 comments on commit 0ad95e1

Please sign in to comment.