From 4acaf92395aa856c70203e0f7aa32155612a0b2a Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Fri, 31 May 2024 19:24:32 +0000 Subject: [PATCH] fix index error in cls token --- mambular/base_models/classifier.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mambular/base_models/classifier.py b/mambular/base_models/classifier.py index 7ba8223..e2abae0 100644 --- a/mambular/base_models/classifier.py +++ b/mambular/base_models/classifier.py @@ -232,9 +232,7 @@ def forward(self, num_features, cat_features): The output predictions of the model. """ batch_size = ( - cat_features[0].size(0) - if cat_features is not None - else num_features[0].size(0) + cat_features[0].size(0) if cat_features != [] else num_features[0].size(0) ) cls_tokens = self.cls_token.expand(batch_size, -1, -1)