Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fixes
Browse files Browse the repository at this point in the history
lostella committed May 27, 2024
1 parent 4486690 commit becf7fe
Showing 3 changed files with 18 additions and 13 deletions.
19 changes: 12 additions & 7 deletions src/gluonts/torch/model/mq_cnn/estimator.py
Original file line number Diff line number Diff line change
@@ -74,7 +74,7 @@ class MQCNNEstimator(PyTorchLightningEstimator):
Length of the prediction, also known as 'horizon'.
context_length (int, optional):
Number of time units that condition the predictions, also known as 'lookback period'.
Defaults to None.
Defaults to `4 * prediction_length`.
num_forking (int, optional):
Decides how much forking to do in the decoder.
(default: context_length if None)
@@ -282,7 +282,11 @@ def __init__(

self.freq = freq
self.prediction_length = prediction_length
self.context_length = context_length
self.context_length = self.context_length = (
context_length
if context_length is not None
else 4 * self.prediction_length
)
self.num_forking = (
min(num_forking, self.context_length)
if num_forking is not None
@@ -400,7 +404,8 @@ def __init__(
self.enc_cnn_init_dim += sum(self.embedding_dimension_dynamic)
self.dec_future_init_dim += sum(self.embedding_dimension_dynamic)
else:
self.cardinality_dynamic = 0
self.cardinality_dynamic = [0]
self.embedding_dimension_dynamic = [0]

if self.use_past_feat_dynamic_real:
assert (
@@ -431,14 +436,14 @@ def __init__(
)
self.enc_mlp_init_dim += sum(self.embedding_dimension_static)
else:
self.cardinality_static = 0
self.embedding_dimension_static = 0
self.cardinality_static = [0]
self.embedding_dimension_static = [0]

self.joint_embedding_dimension = joint_embedding_dimension
if self.joint_embedding_dimension is None:
feat_static_dim = sum(self.embedding_dimension_static)
self.joint_embedding_dimension = int(
channels_seq[-1] * max(np.sqrt(feat_static_dim), 1)
self.channels_seq[-1] * max(np.sqrt(feat_static_dim), 1)
)

if self.use_feat_static_real:
@@ -623,7 +628,7 @@ def _create_instance_splitter(self, mode: str) -> Chain:
FieldName.FEAT_DYNAMIC,
FieldName.FEAT_DYNAMIC_CAT,
]
+ ([FieldName.OBSERVED_VALUES] if mode != "test" else []),
+ ([FieldName.OBSERVED_VALUES] if mode != "test" else [])
)

decoder_disabled_fields = (
4 changes: 2 additions & 2 deletions src/gluonts/torch/model/mq_cnn/lightning_module.py
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from pytorch_lightning import LightningModule
import lightning.pytorch as pl
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from gluonts.torch.model.lightning_util import has_validation_loop
@@ -21,7 +21,7 @@
from .module import MQCNNModel


class MQCNNLightningModule(LightningModule):
class MQCNNLightningModule(pl.LightningModule):
"""
A ``pl.LightningModule`` class that can be used to train a
``MQCNNModel`` with PyTorch Lightning.
8 changes: 4 additions & 4 deletions test/torch/model/test_estimators.py
Original file line number Diff line number Diff line change
@@ -321,10 +321,10 @@ def test_estimator_constant_dataset(
prediction_length=prediction_length,
batch_size=4,
num_batches_per_epoch=3,
num_feat_dynamic_real=3,
num_feat_static_real=1,
num_feat_static_cat=2,
cardinality=[2, 2],
use_feat_dynamic_real=True,
use_feat_static_real=True,
use_feat_static_cat=True,
cardinality_static=[2, 2],
trainer_kwargs=dict(max_epochs=2),
),
],

0 comments on commit becf7fe

Please sign in to comment.