diff --git a/botorch/utils/datasets.py b/botorch/utils/datasets.py index 7afa0c9ca4..f11f5c80e7 100644 --- a/botorch/utils/datasets.py +++ b/botorch/utils/datasets.py @@ -492,6 +492,14 @@ def get_dataset_without_task_feature(self, outcome_name: str) -> SupervisedDatas outcome_names=[outcome_name], ) + def __eq__(self, other: Any) -> bool: + return ( + type(other) is type(self) + and self.datasets == other.datasets + and self.target_outcome_name == other.target_outcome_name + and self.task_feature_index == other.task_feature_index + ) + class ContextualDataset(SupervisedDataset): """This is a contextual dataset that is constructed from either a single @@ -548,7 +556,7 @@ def Y(self) -> Tensor: return torch.cat(Ys, dim=-1) @property - def Yvar(self) -> Tensor: + def Yvar(self) -> Tensor | None: """Concatenates the Yvars from the child datasets to create the Y expected by LCEM model if there are multiple datasets; Or return the Yvar expected by LCEA model if there is only one dataset. diff --git a/test/utils/test_datasets.py b/test/utils/test_datasets.py index 5301066ca0..22d8c24a50 100644 --- a/test/utils/test_datasets.py +++ b/test/utils/test_datasets.py @@ -354,6 +354,17 @@ def test_multi_task(self): ): mt_dataset.X + # Test equality. + self.assertEqual(mt_dataset, mt_dataset) + self.assertNotEqual(mt_dataset, dataset_5) + self.assertNotEqual( + mt_dataset, MultiTaskDataset(datasets=[dataset_1], target_outcome_name="y") + ) + self.assertNotEqual( + mt_dataset, + MultiTaskDataset(datasets=[dataset_1, dataset_5], target_outcome_name="z"), + ) + def test_contextual_datasets(self): num_contexts = 3 feature_names = [f"x_c{i}" for i in range(num_contexts)]