Skip to content

Commit

Permalink
feat(python): Allow loading data from multiple Excel/ODS workbooks and worksheets
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Dec 26, 2024
1 parent a6fffd4 commit 5575f4b
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 13 deletions.
48 changes: 36 additions & 12 deletions py-polars/polars/io/spreadsheet/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import re
import warnings
from collections import defaultdict
from collections.abc import Sequence
from datetime import time
from glob import glob
Expand Down Expand Up @@ -71,6 +72,31 @@ def _standardize_duplicates(s: str) -> str:
return re.sub(r"_duplicated_(\d+)", repl=r"\1", string=s)


def _unpack_sheet_results(
    frames: list[pl.DataFrame] | list[dict[str, pl.DataFrame]],
    *,
    read_multiple_workbooks: bool,
) -> Any:
    """
    Collapse the per-workbook read results into a single return value.

    Parameters
    ----------
    frames
        One entry per workbook that was read; each entry is either a single
        DataFrame or a dict mapping sheet name to DataFrame.
    read_multiple_workbooks
        When False, the first entry of `frames` is returned untouched.
        When True, the entries are combined with a "vertical_relaxed" concat.

    Returns
    -------
    The first entry of `frames`, a single concatenated DataFrame, or a dict
    mapping each sheet name to the concatenation of that sheet's frames.

    Raises
    ------
    NoDataError
        If `frames` is empty.
    """
    if not frames:
        msg = "no data found in the given workbook(s) and sheet(s)"
        raise NoDataError(msg)

    head = frames[0]
    if not read_multiple_workbooks:
        # single workbook: hand back its result exactly as it was read
        return head

    if not isinstance(head, pl.DataFrame):
        # several workbooks, each yielding a {sheet name: frame} mapping;
        # the first entry decides the shape (entries assumed homogeneous)
        frames_by_sheet = defaultdict(list)
        for workbook_result in frames:
            for sheet_name, frame in workbook_result.items():  # type: ignore[union-attr]
                frames_by_sheet[sheet_name].append(frame)
        return {
            sheet_name: concat(sheet_parts, how="vertical_relaxed")
            for sheet_name, sheet_parts in frames_by_sheet.items()
        }

    # several workbooks, each yielding one frame: stack them vertically
    return concat(frames, how="vertical_relaxed")  # type: ignore[type-var]


@overload
def read_excel(
source: FileSource,
Expand Down Expand Up @@ -338,7 +364,7 @@ def read_excel(
... ) # doctest: +SKIP
"""
sources, read_multiple_workbooks = _sources(source)
frames = [
frames: list[pl.DataFrame] | list[dict[str, pl.DataFrame]] = [ # type: ignore[assignment]
_read_spreadsheet(
src,
sheet_id=sheet_id,
Expand All @@ -357,9 +383,10 @@ def read_excel(
)
for src in sources
]
if read_multiple_workbooks:
return concat(frames, how="vertical_relaxed") # type: ignore[type-var]
return frames[0]
return _unpack_sheet_results(
frames=frames,
read_multiple_workbooks=read_multiple_workbooks,
)


@overload
Expand Down Expand Up @@ -540,7 +567,7 @@ def read_ods(
... ) # doctest: +SKIP
"""
sources, read_multiple_workbooks = _sources(source)
frames = [
frames: list[pl.DataFrame] | list[dict[str, pl.DataFrame]] = [ # type: ignore[assignment]
_read_spreadsheet(
src,
sheet_id=sheet_id,
Expand All @@ -559,9 +586,10 @@ def read_ods(
)
for src in sources
]
if read_multiple_workbooks:
return concat(frames, how="vertical_relaxed") # type: ignore[type-var]
return frames[0]
return _unpack_sheet_results(
frames=frames,
read_multiple_workbooks=read_multiple_workbooks,
)


def _read_spreadsheet(
Expand Down Expand Up @@ -605,10 +633,6 @@ def _read_spreadsheet(
sheet_names, return_multiple_sheets = _get_sheet_names(
sheet_id, sheet_name, worksheets
)
if read_multiple_workbooks and return_multiple_sheets:
msg = "cannot return multiple sheets from multiple workbooks"
raise ValueError(msg)

parsed_sheets = {
name: reader_fn(
parser=parser,
Expand Down
26 changes: 25 additions & 1 deletion py-polars/tests/unit/io/test_spreadsheet.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,14 @@ def test_read_excel_multiple_worksheets(
],
)
def test_read_excel_multiple_workbooks(
read_spreadsheet: Callable[..., pl.DataFrame],
read_spreadsheet: Callable[..., Any],
source: str,
params: dict[str, str],
request: pytest.FixtureRequest,
) -> None:
spreadsheet_path = request.getfixturevalue(source)

# multiple workbooks, single worksheet
df = read_spreadsheet(
[
spreadsheet_path,
Expand All @@ -206,6 +207,29 @@ def test_read_excel_multiple_workbooks(
)
assert_frame_equal(df, expected)

# multiple workbooks, multiple worksheets
res = read_spreadsheet(
[
spreadsheet_path,
spreadsheet_path,
spreadsheet_path,
],
sheet_id=None,
sheet_name=["test1", "test2"],
**params,
)
expected_frames = {
"test1": pl.DataFrame(
{"hello": ["Row 1", "Row 2", "Row 1", "Row 2", "Row 1", "Row 2"]}
),
"test2": pl.DataFrame(
{"world": ["Row 3", "Row 4", "Row 3", "Row 4", "Row 3", "Row 4"]}
),
}
assert sorted(res) == sorted(expected_frames)
assert_frame_equal(res["test1"], expected_frames["test1"])
assert_frame_equal(res["test2"], expected_frames["test2"])


@pytest.mark.parametrize(
("read_spreadsheet", "source", "params"),
Expand Down

0 comments on commit 5575f4b

Please sign in to comment.