diff --git a/python/fastexcel/__init__.py b/python/fastexcel/__init__.py index 9f7c0ee..430b580 100644 --- a/python/fastexcel/__init__.py +++ b/python/fastexcel/__init__.py @@ -88,6 +88,7 @@ def load_sheet_by_name( column_names: list[str] | None = None, skip_rows: int = 0, n_rows: int | None = None, + schema_sample_rows: int | None = 1_000, ) -> ExcelSheet: """Loads a sheet by name. @@ -97,8 +98,11 @@ def load_sheet_by_name( :param column_names: Overrides headers found in the document. If `column_names` is used, `header_row` will be ignored. :param n_rows: Specifies how many rows should be loaded. If `None`, all rows are loaded - :param skip_rows: Specifies how many should be skipped after the header. If `header_row` is - `None`, it skips the number of rows from the sheet's start. + :param skip_rows: Specifies how many rows should be skipped after the header. If + `header_row` is `None`, it skips the number of rows from the sheet's + start. + :param schema_sample_rows: Specifies how many rows should be used to determine the dtype of + a column. If `None`, all rows will be used. """ return ExcelSheet( self._reader.load_sheet_by_name( @@ -107,6 +111,7 @@ def load_sheet_by_name( column_names=column_names, skip_rows=skip_rows, n_rows=n_rows, + schema_sample_rows=schema_sample_rows, ) ) @@ -118,6 +123,7 @@ def load_sheet_by_idx( column_names: list[str] | None = None, skip_rows: int = 0, n_rows: int | None = None, + schema_sample_rows: int | None = 1_000, ) -> ExcelSheet: """Loads a sheet by index. @@ -127,8 +133,11 @@ def load_sheet_by_idx( :param column_names: Overrides headers found in the document. If `column_names` is used, `header_row` will be ignored. :param n_rows: Specifies how many rows should be loaded. If `None`, all rows are loaded - :param skip_rows: Specifies how many should be skipped after the header. If `header_row` is - `None`, it skips the number of rows from the sheet's start. + :param skip_rows: Specifies how many rows should be skipped after the header. If + `header_row` is `None`, it skips the number of rows from the sheet's + start. + :param schema_sample_rows: Specifies how many rows should be used to determine the dtype of + a column. If `None`, all rows will be used. """ if idx < 0: raise ValueError(f"Expected idx to be > 0, got {idx}") @@ -139,6 +148,7 @@ def load_sheet_by_idx( column_names=column_names, skip_rows=skip_rows, n_rows=n_rows, + schema_sample_rows=schema_sample_rows, ) ) @@ -150,6 +160,7 @@ def load_sheet( column_names: list[str] | None = None, skip_rows: int = 0, n_rows: int | None = None, + schema_sample_rows: int | None = 1_000, ) -> ExcelSheet: """Loads a sheet by name if a string is passed or by index if an integer is passed. @@ -162,6 +173,7 @@ def load_sheet( column_names=column_names, skip_rows=skip_rows, n_rows=n_rows, + schema_sample_rows=schema_sample_rows, ) if isinstance(idx_or_name, int) else self.load_sheet_by_name( @@ -170,6 +182,7 @@ def load_sheet( column_names=column_names, skip_rows=skip_rows, n_rows=n_rows, + schema_sample_rows=schema_sample_rows, ) ) diff --git a/python/fastexcel/_fastexcel.pyi b/python/fastexcel/_fastexcel.pyi index b9fd5eb..26e1841 100644 --- a/python/fastexcel/_fastexcel.pyi +++ b/python/fastexcel/_fastexcel.pyi @@ -32,6 +32,7 @@ class _ExcelReader: column_names: list[str] | None = None, skip_rows: int = 0, n_rows: int | None = None, + schema_sample_rows: int | None = 1_000, ) -> _ExcelSheet: ... def load_sheet_by_idx( self, @@ -41,6 +42,7 @@ class _ExcelReader: column_names: list[str] | None = None, skip_rows: int = 0, n_rows: int | None = None, + schema_sample_rows: int | None = 1_000, ) -> _ExcelSheet: ... def load_sheet( self, @@ -50,6 +52,7 @@ class _ExcelReader: column_names: list[str] | None = None, skip_rows: int = 0, n_rows: int | None = None, + schema_sample_rows: int | None = 1_000, ) -> _ExcelSheet: ... @property def sheet_names(self) -> list[str]: ... diff --git a/python/tests/test_dtypes.py b/python/tests/test_dtypes.py index 19b3455..6798de7 100644 --- a/python/tests/test_dtypes.py +++ b/python/tests/test_dtypes.py @@ -1,18 +1,20 @@ +from __future__ import annotations + from datetime import datetime +from typing import Any import fastexcel import pandas as pd import polars as pl +import pytest from pandas.testing import assert_frame_equal as pd_assert_frame_equal from polars.testing import assert_frame_equal as pl_assert_frame_equal from utils import path_for_fixture -def test_sheet_with_mixed_dtypes() -> None: - excel_reader = fastexcel.read_excel(path_for_fixture("fixture-multi-dtypes-columns.xlsx")) - sheet = excel_reader.load_sheet(0) - - expected_data = { +@pytest.fixture +def expected_data() -> dict[str, list[Any]]: + return { "Employee ID": [ "123456", "44333", @@ -40,6 +42,54 @@ def test_sheet_with_mixed_dtypes() -> None: "Asset ID": ["84444"] * 7 + ["ABC123"] * 2, } + +def test_sheet_with_mixed_dtypes(expected_data: dict[str, list[Any]]) -> None: + excel_reader = fastexcel.read_excel(path_for_fixture("fixture-multi-dtypes-columns.xlsx")) + sheet = excel_reader.load_sheet(0) + + pd_df = sheet.to_pandas() + pd_assert_frame_equal(pd_df, pd.DataFrame(expected_data).astype({"Date": "datetime64[ms]"})) + + pl_df = sheet.to_polars() + pl_assert_frame_equal( + pl_df, pl.DataFrame(expected_data, schema_overrides={"Date": pl.Datetime(time_unit="ms")}) + ) + + +def test_sheet_with_mixed_dtypes_and_sample_rows(expected_data: dict[str, list[Any]]) -> None: + excel_reader = fastexcel.read_excel(path_for_fixture("fixture-multi-dtypes-columns.xlsx")) + + # Since we skip rows here, the dtypes should be correctly guessed, even if we only check 5 rows + sheet = excel_reader.load_sheet(0, schema_sample_rows=5, skip_rows=5) + + expected_data_subset = {col_name: values[5:] for col_name, values in expected_data.items()} + pd_df = sheet.to_pandas() + pd_assert_frame_equal( + pd_df, pd.DataFrame(expected_data_subset).astype({"Date": "datetime64[ms]"}) + ) + + pl_df = sheet.to_polars() + pl_assert_frame_equal( + pl_df, + pl.DataFrame(expected_data_subset, schema_overrides={"Date": pl.Datetime(time_unit="ms")}), + ) + + # Guess the sheet's dtypes on 5 rows only + sheet = excel_reader.load_sheet(0, schema_sample_rows=5) + # String fields should not have been loaded + expected_data["Employee ID"] = [ + 123456.0, + 44333.0, + 44333.0, + 87878.0, + 87878.0, + None, + 135967.0, + None, + None, + ] + expected_data["Asset ID"] = [84444.0] * 7 + [None] * 2 + pd_df = sheet.to_pandas() pd_assert_frame_equal(pd_df, pd.DataFrame(expected_data).astype({"Date": "datetime64[ms]"})) diff --git a/src/types/excelreader.rs b/src/types/excelreader.rs index 592b235..0c9aadd 100644 --- a/src/types/excelreader.rs +++ b/src/types/excelreader.rs @@ -44,7 +44,8 @@ impl ExcelReader { header_row = 0, column_names = None, skip_rows = 0, - n_rows = None + n_rows = None, + schema_sample_rows = 1_000, ))] pub fn load_sheet_by_name( &mut self, @@ -53,6 +54,7 @@ impl ExcelReader { column_names: Option>, skip_rows: usize, n_rows: Option, + schema_sample_rows: Option, ) -> Result { let range = self .sheets @@ -61,7 +63,13 @@ impl ExcelReader { let header = Header::new(header_row, column_names); let pagination = Pagination::new(skip_rows, n_rows, &range)?; - Ok(ExcelSheet::new(name, range, header, pagination)) + Ok(ExcelSheet::new( + name, + range, + header, + pagination, + schema_sample_rows, + )) } #[pyo3(signature = ( @@ -70,8 +78,9 @@ impl ExcelReader { header_row = 0, column_names = None, skip_rows = 0, - n_rows = None) - )] + n_rows = None, + schema_sample_rows = 1_000, + ))] pub fn load_sheet_by_idx( &mut self, idx: usize, @@ -79,6 +88,7 @@ impl ExcelReader { column_names: Option>, skip_rows: usize, n_rows: Option, + schema_sample_rows: Option, ) -> Result { let name = self .sheet_names @@ -98,6 +108,12 @@ impl ExcelReader { let header = Header::new(header_row, column_names); let pagination = Pagination::new(skip_rows, n_rows, &range)?; - Ok(ExcelSheet::new(name, range, header, pagination)) + Ok(ExcelSheet::new( + name, + range, + header, + pagination, + schema_sample_rows, + )) } } diff --git a/src/types/excelsheet.rs b/src/types/excelsheet.rs index b12e2fd..e7ec2eb 100644 --- a/src/types/excelsheet.rs +++ b/src/types/excelsheet.rs @@ -76,6 +76,7 @@ pub(crate) struct ExcelSheet { height: Option, total_height: Option, width: Option, + schema_sample_rows: Option, } impl ExcelSheet { @@ -88,12 +89,14 @@ impl ExcelSheet { data: Range, header: Header, pagination: Pagination, + schema_sample_rows: Option, ) -> Self { ExcelSheet { name, header, pagination, data, + schema_sample_rows, height: None, total_height: None, width: None, @@ -138,6 +141,10 @@ impl ExcelSheet { upper_bound } + + pub(crate) fn schema_sample_rows(&self) -> &Option { + &self.schema_sample_rows + } } fn create_boolean_array( @@ -240,11 +247,16 @@ impl TryFrom<&ExcelSheet> for Schema { type Error = anyhow::Error; fn try_from(value: &ExcelSheet) -> Result { + // Checking how many rows we want to use to determine the dtype for a column. If sample_rows is + // not provided, we sample limit rows, i.e on the entire column + let sample_rows = value.offset() + value.schema_sample_rows().unwrap_or(value.limit()); + arrow_schema_from_column_names_and_range( value.data(), &value.column_names(), value.offset(), - value.limit(), + // If sample_rows is higher than the sheet's limit, use the limit instead + std::cmp::min(sample_rows, value.limit()), ) } }