diff --git a/ax/core/data.py b/ax/core/data.py index de2fbde8a19..d22481ff720 100644 --- a/ax/core/data.py +++ b/ax/core/data.py @@ -270,15 +270,14 @@ def from_multiple( Args: data: Iterable of Ax objects of this class to combine. """ - incompatible_types = { - type(datum) for datum in data if not isinstance(datum, cls) - } - if incompatible_types: - raise TypeError( - f"All data objects must be instances of class {cls}. Got " - f"{incompatible_types}." - ) - dfs = [datum.df for datum in data] + dfs = [] + for datum in data: + if not isinstance(datum, cls): + raise TypeError( + f"All data objects must be instances of class {cls}. Got " + f"{type(datum)}." + ) + dfs.append(datum.df) if len(dfs) == 0: return cls() diff --git a/ax/core/tests/test_data.py b/ax/core/tests/test_data.py index 1e6f4d4e875..e5c7e05593d 100644 --- a/ax/core/tests/test_data.py +++ b/ax/core/tests/test_data.py @@ -290,6 +290,10 @@ def test_FromMultipleDataMismatchedTypes(self) -> None: ) Data.from_multiple_data([data_elt_A, data_elt_B]) + def test_from_multiple_with_generator(self) -> None: + data = Data.from_multiple_data(Data(df=self.df) for _ in range(2)) + self.assertEqual(len(data.df), 2 * len(self.df)) + def test_GetFilteredResults(self) -> None: data = Data(df=self.df) # pyre-fixme[6]: For 1st param expected `Dict[str, typing.Any]` but got `str`.