diff --git a/tsfm_public/models/tinytimemixer/modeling_tinytimemixer.py b/tsfm_public/models/tinytimemixer/modeling_tinytimemixer.py index 4cb56ef..a4749c5 100644 --- a/tsfm_public/models/tinytimemixer/modeling_tinytimemixer.py +++ b/tsfm_public/models/tinytimemixer/modeling_tinytimemixer.py @@ -1840,7 +1840,9 @@ def forward( ) sequence_length = ( - self.config.masked_context_length if self.config.masked_context_length is not None else self.context_length + self.config.masked_context_length + if self.config.masked_context_length is not None + else self.config.context_length ) if past_values.shape[1] > sequence_length: