Skip to content

Commit

Permalink
fix index error in cls token
Browse files Browse the repository at this point in the history
  • Loading branch information
AnFreTh committed May 31, 2024
1 parent 141c595 commit 4acaf92
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions mambular/base_models/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,7 @@ def forward(self, num_features, cat_features):
The output predictions of the model.
"""
batch_size = (
cat_features[0].size(0)
if cat_features is not None
else num_features[0].size(0)
cat_features[0].size(0) if cat_features != [] else num_features[0].size(0)
)
cls_tokens = self.cls_token.expand(batch_size, -1, -1)

Expand Down

0 comments on commit 4acaf92

Please sign in to comment.