Commit

float to float32 in datasets
mkumar73 committed May 31, 2024
1 parent 9b3443b commit 77ca63e
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions mambular/utils/dataset.py
@@ -64,12 +64,14 @@ def __getitem__(self, idx):
             feature_tensor[idx] for feature_tensor in self.cat_features_list
         ]
         num_features = [
-            feature_tensor[idx] for feature_tensor in self.num_features_list
+            torch.tensor(feature_tensor[idx], dtype=torch.float32) for feature_tensor in self.num_features_list
+
         ]
         label = self.labels[idx]
         if self.regression:
             # Convert the label to float for regression tasks
-            label = float(label)
+            # label = float(label)
+            label = torch.tensor(label, dtype=torch.float32)
 
         # Keep categorical and numerical features separate
         return cat_features, num_features, label
@@ -114,12 +116,13 @@ def __getitem__(self, idx):
             feature_tensor[idx] for feature_tensor in self.cat_features_list
         ]
         num_features = [
-            feature_tensor[idx] for feature_tensor in self.num_features_list
+            torch.tensor(feature_tensor[idx], dtype=torch.float32) for feature_tensor in self.num_features_list
         ]
         label = self.labels[idx]
         if self.regression:
             # Convert the label to float for regression tasks
-            label = float(label)
+            # label = float(label)
+            label = torch.tensor(label, dtype=torch.float32)
 
         # Keep categorical and numerical features separate
         return cat_features, num_features, label
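For context, a minimal, hypothetical sketch of a dataset class with the patched __getitem__. Only the lines touched by this commit are taken from the hunks above; the class name, constructor, and __len__ are illustrative scaffolding assumed for the sake of a runnable example.

import torch
from torch.utils.data import Dataset


class ExampleMambularDataset(Dataset):
    """Illustrative dataset holding per-feature arrays for categorical and
    numerical inputs, plus labels; attribute names follow the diff above."""

    def __init__(self, cat_features_list, num_features_list, labels, regression=True):
        self.cat_features_list = cat_features_list  # list of per-feature arrays/tensors
        self.num_features_list = num_features_list
        self.labels = labels
        self.regression = regression

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        cat_features = [
            feature_tensor[idx] for feature_tensor in self.cat_features_list
        ]
        num_features = [
            # numerical features are now cast explicitly to float32 tensors
            torch.tensor(feature_tensor[idx], dtype=torch.float32)
            for feature_tensor in self.num_features_list
        ]
        label = self.labels[idx]
        if self.regression:
            # the regression target is likewise returned as a float32 tensor
            # instead of a plain Python float
            label = torch.tensor(label, dtype=torch.float32)

        # Keep categorical and numerical features separate
        return cat_features, num_features, label

Casting inside __getitem__ leaves the stored arrays in their original dtype and only materializes float32 tensors per sample, which matches PyTorch's default floating-point dtype and avoids float64 values (e.g. from NumPy) leaking into the training loop.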
