Skip to content

Commit

Permalink
Properly handle generators in Data.from_multiple (#2318)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2318

Since it previously looped through the input twice, it would consume the generator in the first pass and return empty data.

Reviewed By: Balandat

Differential Revision: D55723466

fbshipit-source-id: da77641218c1b0f8a13764893a937735d2d9c515
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Apr 4, 2024
1 parent 032a4f0 commit 1b29a78
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
17 changes: 8 additions & 9 deletions ax/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,15 +270,14 @@ def from_multiple(
Args:
data: Iterable of Ax objects of this class to combine.
"""
incompatible_types = {
type(datum) for datum in data if not isinstance(datum, cls)
}
if incompatible_types:
raise TypeError(
f"All data objects must be instances of class {cls}. Got "
f"{incompatible_types}."
)
dfs = [datum.df for datum in data]
dfs = []
for datum in data:
if not isinstance(datum, cls):
raise TypeError(
f"All data objects must be instances of class {cls}. Got "
f"{type(datum)}."
)
dfs.append(datum.df)

if len(dfs) == 0:
return cls()
Expand Down
4 changes: 4 additions & 0 deletions ax/core/tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,10 @@ def test_FromMultipleDataMismatchedTypes(self) -> None:
)
Data.from_multiple_data([data_elt_A, data_elt_B])

def test_from_multiple_with_generator(self) -> None:
data = Data.from_multiple_data(Data(df=self.df) for _ in range(2))
self.assertEqual(len(data.df), 2 * len(self.df))

def test_GetFilteredResults(self) -> None:
data = Data(df=self.df)
# pyre-fixme[6]: For 1st param expected `Dict[str, typing.Any]` but got `str`.
Expand Down

0 comments on commit 1b29a78

Please sign in to comment.