diff --git a/tst/transformer.py b/tst/transformer.py index 1062b13..4d5a070 100644 --- a/tst/transformer.py +++ b/tst/transformer.py @@ -47,9 +47,9 @@ class Transformer(nn.Module): One of ``'chunk'``, ``'window'`` or ``None``. Default is ``'chunk'``. pe: Type of positional encoding to add. - Must be one of ``'original'``, ``'regular'`` or ``None``. Default is ``None``. + Must be one of `original, `regular` or `None. Default is `None`. pe_period: - If using the ``'regular'` pe, then we can define the period. Default is ``24``. + If using the ``regular`` pe, then we can define the period. Default is ``24``. """ def __init__(self, @@ -64,7 +64,7 @@ def __init__(self, dropout: float = 0.3, chunk_mode: str = 'chunk', pe: str = None, - pe_period: int = 24): + pe_period: int = None): """Create transformer structure from Encoder and Decoder blocks.""" super().__init__() @@ -95,7 +95,9 @@ def __init__(self, if pe in pe_functions.keys(): self._generate_PE = pe_functions[pe] - self._pe_period = pe_period + + if pe == 'regular' and pe_period is not None: + self._pe_period = pe_period elif pe is None: self._generate_PE = None else: