From 5575f4b4feaa3c0aaf34758e48fb98d93d12a70b Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Thu, 26 Dec 2024 10:32:14 +0000 Subject: [PATCH] feat(python): Allow loading data from multiple Excel/ODS workbooks and worksheets --- py-polars/polars/io/spreadsheet/functions.py | 48 +++++++++++++++----- py-polars/tests/unit/io/test_spreadsheet.py | 26 ++++++++++- 2 files changed, 61 insertions(+), 13 deletions(-) diff --git a/py-polars/polars/io/spreadsheet/functions.py b/py-polars/polars/io/spreadsheet/functions.py index 1d32589302bc..a0d0b3f431d2 100644 --- a/py-polars/polars/io/spreadsheet/functions.py +++ b/py-polars/polars/io/spreadsheet/functions.py @@ -3,6 +3,7 @@ import os import re import warnings +from collections import defaultdict from collections.abc import Sequence from datetime import time from glob import glob @@ -71,6 +72,31 @@ def _standardize_duplicates(s: str) -> str: return re.sub(r"_duplicated_(\d+)", repl=r"\1", string=s) +def _unpack_sheet_results( + frames: list[pl.DataFrame] | list[dict[str, pl.DataFrame]], + *, + read_multiple_workbooks: bool, +) -> Any: + if not frames: + msg = "no data found in the given workbook(s) and sheet(s)" + raise NoDataError(msg) + + if not read_multiple_workbooks: + # single workbook: one frame, or a dict of frames (multiple sheets) + return frames[0] + + if isinstance(frames[0], pl.DataFrame): + # one sheet from multiple workbooks + return concat(frames, how="vertical_relaxed") # type: ignore[type-var] + else: + # multiple sheets from multiple workbooks + sheet_frames = defaultdict(list) + for res in frames: + for sheet, df in res.items(): # type: ignore[union-attr] + sheet_frames[sheet].append(df) + return {k: concat(v, how="vertical_relaxed") for k, v in sheet_frames.items()} + + @overload def read_excel( source: FileSource, @@ -338,7 +364,7 @@ def read_excel( ... 
) # doctest: +SKIP """ sources, read_multiple_workbooks = _sources(source) - frames = [ + frames: list[pl.DataFrame] | list[dict[str, pl.DataFrame]] = [ # type: ignore[assignment] _read_spreadsheet( src, sheet_id=sheet_id, @@ -357,9 +383,10 @@ def read_excel( ) for src in sources ] - if read_multiple_workbooks: - return concat(frames, how="vertical_relaxed") # type: ignore[type-var] - return frames[0] + return _unpack_sheet_results( + frames=frames, + read_multiple_workbooks=read_multiple_workbooks, + ) @overload @@ -540,7 +567,7 @@ def read_ods( ... ) # doctest: +SKIP """ sources, read_multiple_workbooks = _sources(source) - frames = [ + frames: list[pl.DataFrame] | list[dict[str, pl.DataFrame]] = [ # type: ignore[assignment] _read_spreadsheet( src, sheet_id=sheet_id, @@ -559,9 +586,10 @@ def read_ods( ) for src in sources ] - if read_multiple_workbooks: - return concat(frames, how="vertical_relaxed") # type: ignore[type-var] - return frames[0] + return _unpack_sheet_results( + frames=frames, + read_multiple_workbooks=read_multiple_workbooks, + ) def _read_spreadsheet( @@ -605,10 +633,6 @@ def _read_spreadsheet( sheet_names, return_multiple_sheets = _get_sheet_names( sheet_id, sheet_name, worksheets ) - if read_multiple_workbooks and return_multiple_sheets: - msg = "cannot return multiple sheets from multiple workbooks" - raise ValueError(msg) - parsed_sheets = { name: reader_fn( parser=parser, diff --git a/py-polars/tests/unit/io/test_spreadsheet.py b/py-polars/tests/unit/io/test_spreadsheet.py index 433d652b07ec..409a46db4e35 100644 --- a/py-polars/tests/unit/io/test_spreadsheet.py +++ b/py-polars/tests/unit/io/test_spreadsheet.py @@ -184,13 +184,14 @@ def test_read_excel_multiple_worksheets( ], ) def test_read_excel_multiple_workbooks( - read_spreadsheet: Callable[..., pl.DataFrame], + read_spreadsheet: Callable[..., Any], source: str, params: dict[str, str], request: pytest.FixtureRequest, ) -> None: spreadsheet_path = request.getfixturevalue(source) + # 
multiple workbooks, single worksheet df = read_spreadsheet( [ spreadsheet_path, @@ -206,6 +207,29 @@ def test_read_excel_multiple_workbooks( ) assert_frame_equal(df, expected) + # multiple workbooks, multiple worksheets + res = read_spreadsheet( + [ + spreadsheet_path, + spreadsheet_path, + spreadsheet_path, + ], + sheet_id=None, + sheet_name=["test1", "test2"], + **params, + ) + expected_frames = { + "test1": pl.DataFrame( + {"hello": ["Row 1", "Row 2", "Row 1", "Row 2", "Row 1", "Row 2"]} + ), + "test2": pl.DataFrame( + {"world": ["Row 3", "Row 4", "Row 3", "Row 4", "Row 3", "Row 4"]} + ), + } + assert sorted(res) == sorted(expected_frames) + assert_frame_equal(res["test1"], expected_frames["test1"]) + assert_frame_equal(res["test2"], expected_frames["test2"]) + @pytest.mark.parametrize( ("read_spreadsheet", "source", "params"),