From 9d37e905639e0d4983e52ce425306b8161760ee4 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Thu, 24 Oct 2024 12:33:40 -0700 Subject: [PATCH] Implement MultiTaskDataset.__eq__ (#2594) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/2594 Previously, this would fallback to `SupervisedDataset.__eq__`, which uses `self.X` for comparison. If the underlying datasets have heterogeneous feature sets, `self.X` errors out. The new `MultiTaskDataset.__eq__` resolves this issue by comparing the underlying datasets one by one. Reviewed By: Balandat Differential Revision: D64911436 fbshipit-source-id: ecb7343d86c4526d06f61725c1663e50f1f1902f --- botorch/utils/datasets.py | 10 +++++++++- test/utils/test_datasets.py | 11 +++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/botorch/utils/datasets.py b/botorch/utils/datasets.py index 7afa0c9ca4..f11f5c80e7 100644 --- a/botorch/utils/datasets.py +++ b/botorch/utils/datasets.py @@ -492,6 +492,14 @@ def get_dataset_without_task_feature(self, outcome_name: str) -> SupervisedDatas outcome_names=[outcome_name], ) + def __eq__(self, other: Any) -> bool: + return ( + type(other) is type(self) + and self.datasets == other.datasets + and self.target_outcome_name == other.target_outcome_name + and self.task_feature_index == other.task_feature_index + ) + class ContextualDataset(SupervisedDataset): """This is a contextual dataset that is constructed from either a single @@ -548,7 +556,7 @@ def Y(self) -> Tensor: return torch.cat(Ys, dim=-1) @property - def Yvar(self) -> Tensor: + def Yvar(self) -> Tensor | None: """Concatenates the Yvars from the child datasets to create the Y expected by LCEM model if there are multiple datasets; Or return the Yvar expected by LCEA model if there is only one dataset. diff --git a/test/utils/test_datasets.py b/test/utils/test_datasets.py index 5301066ca0..22d8c24a50 100644 --- a/test/utils/test_datasets.py +++ b/test/utils/test_datasets.py @@ -354,6 +354,17 @@ def test_multi_task(self): ): mt_dataset.X + # Test equality. + self.assertEqual(mt_dataset, mt_dataset) + self.assertNotEqual(mt_dataset, dataset_5) + self.assertNotEqual( + mt_dataset, MultiTaskDataset(datasets=[dataset_1], target_outcome_name="y") + ) + self.assertNotEqual( + mt_dataset, + MultiTaskDataset(datasets=[dataset_1, dataset_5], target_outcome_name="z"), + ) + def test_contextual_datasets(self): num_contexts = 3 feature_names = [f"x_c{i}" for i in range(num_contexts)]