Skip to content

Commit

Permalink
[ENH] refactor __init__ modules to no longer contain classes (#1738)
Browse files Browse the repository at this point in the history
This PR refactors the `__init__` modules in `models`, moving any content
such as neural network classes to submodules.

Imports will continue working unchanged, as the native contents of the
module are still exported.

This is preparing a further structural refactors.

The moved files are also not changed, in order to retain the commit as
one that moves files, hence retaining edit history. In order for this to
take effect on `main`, PR
#1739 must be merged
first (even though this will result in an intermediate, non-passing
state).
  • Loading branch information
fkiraly authored Dec 26, 2024
1 parent 1d60aa1 commit 0aac3e2
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 0 deletions.
5 changes: 5 additions & 0 deletions pytorch_forecasting/models/deepar/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""DeepAR: Probabilistic forecasting with autoregressive recurrent networks."""

from pytorch_forecasting.models.deepar._deepar import DeepAR

__all__ = ["DeepAR"]
6 changes: 6 additions & 0 deletions pytorch_forecasting/models/mlp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Simple models based on fully connected networks."""

from pytorch_forecasting.models.mlp._decodermlp import DecoderMLP
from pytorch_forecasting.models.mlp.submodules import FullyConnectedModule

__all__ = ["DecoderMLP", "FullyConnectedModule"]
6 changes: 6 additions & 0 deletions pytorch_forecasting/models/nbeats/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""N-Beats model for timeseries forecasting without covariates."""

from pytorch_forecasting.models.nbeats._nbeats import NBeats
from pytorch_forecasting.models.nbeats.sub_modules import NBEATSGenericBlock, NBEATSSeasonalBlock, NBEATSTrendBlock

__all__ = ["NBeats", "NBEATSGenericBlock", "NBEATSSeasonalBlock", "NBEATSTrendBlock"]
6 changes: 6 additions & 0 deletions pytorch_forecasting/models/nhits/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""N-HiTS model for timeseries forecasting with covariates."""

from pytorch_forecasting.models.nhits._nhits import NHiTS
from pytorch_forecasting.models.nhits.sub_modules import NHiTS as NHiTSModule

__all__ = ["NHits", "NHiTSModule"]
5 changes: 5 additions & 0 deletions pytorch_forecasting/models/rnn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Simple recurrent model - either with LSTM or GRU cells."""

from pytorch_forecasting.models.rnn._rnn import RecurrentNetwork

__all__ = ["RecurrentNetwork"]
21 changes: 21 additions & 0 deletions pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Temporal fusion transformer for forecasting timeseries."""

from pytorch_forecasting.models.temporal_fusion_transformer._tft import TemporalFusionTransformer
from pytorch_forecasting.models.temporal_fusion_transformer.sub_modules import (
AddNorm,
GateAddNorm,
GatedLinearUnit,
GatedResidualNetwork,
InterpretableMultiHeadAttention,
VariableSelectionNetwork,
)

__all__ = [
"TemporalFusionTransformer",
"AddNorm",
"GateAddNorm",
"GatedLinearUnit",
"GatedResidualNetwork",
"InterpretableMultiHeadAttention",
"VariableSelectionNetwork",
]

0 comments on commit 0aac3e2

Please sign in to comment.