Skip to content

Commit

Permalink
feat: added a schema_sample_rows param
Browse files Browse the repository at this point in the history
Signed-off-by: Luka Peschke <[email protected]>
  • Loading branch information
lukapeschke committed Feb 9, 2024
1 parent e2f91a8 commit e4a69bc
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 15 deletions.
21 changes: 17 additions & 4 deletions python/fastexcel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def load_sheet_by_name(
column_names: list[str] | None = None,
skip_rows: int = 0,
n_rows: int | None = None,
schema_sample_rows: int | None = 1_000,
) -> ExcelSheet:
"""Loads a sheet by name.
Expand All @@ -97,8 +98,11 @@ def load_sheet_by_name(
:param column_names: Overrides headers found in the document. If `column_names` is used,
`header_row` will be ignored.
:param n_rows: Specifies how many rows should be loaded. If `None`, all rows are loaded
:param skip_rows: Specifies how many should be skipped after the header. If `header_row` is
`None`, it skips the number of rows from the sheet's start.
:param skip_rows: Specifies how many rows should be skipped after the header. If
`header_row` is `None`, it skips the number of rows from the sheet's
start.
:param schema_sample_rows: Specifies how many rows should be used to determine the dtype of
a column. If `None`, all rows will be used.
"""
return ExcelSheet(
self._reader.load_sheet_by_name(
Expand All @@ -107,6 +111,7 @@ def load_sheet_by_name(
column_names=column_names,
skip_rows=skip_rows,
n_rows=n_rows,
schema_sample_rows=schema_sample_rows,
)
)

Expand All @@ -118,6 +123,7 @@ def load_sheet_by_idx(
column_names: list[str] | None = None,
skip_rows: int = 0,
n_rows: int | None = None,
schema_sample_rows: int | None = 1_000,
) -> ExcelSheet:
"""Loads a sheet by index.
Expand All @@ -127,8 +133,11 @@ def load_sheet_by_idx(
:param column_names: Overrides headers found in the document. If `column_names` is used,
`header_row` will be ignored.
:param n_rows: Specifies how many rows should be loaded. If `None`, all rows are loaded
:param skip_rows: Specifies how many should be skipped after the header. If `header_row` is
`None`, it skips the number of rows from the sheet's start.
:param skip_rows: Specifies how many rows should be skipped after the header. If
`header_row` is `None`, it skips the number of rows from the sheet's
start.
:param schema_sample_rows: Specifies how many rows should be used to determine the dtype of
a column. If `None`, all rows will be used.
"""
if idx < 0:
raise ValueError(f"Expected idx to be > 0, got {idx}")
Expand All @@ -139,6 +148,7 @@ def load_sheet_by_idx(
column_names=column_names,
skip_rows=skip_rows,
n_rows=n_rows,
schema_sample_rows=schema_sample_rows,
)
)

Expand All @@ -150,6 +160,7 @@ def load_sheet(
column_names: list[str] | None = None,
skip_rows: int = 0,
n_rows: int | None = None,
schema_sample_rows: int | None = 1_000,
) -> ExcelSheet:
"""Loads a sheet by name if a string is passed or by index if an integer is passed.
Expand All @@ -162,6 +173,7 @@ def load_sheet(
column_names=column_names,
skip_rows=skip_rows,
n_rows=n_rows,
schema_sample_rows=schema_sample_rows,
)
if isinstance(idx_or_name, int)
else self.load_sheet_by_name(
Expand All @@ -170,6 +182,7 @@ def load_sheet(
column_names=column_names,
skip_rows=skip_rows,
n_rows=n_rows,
schema_sample_rows=schema_sample_rows,
)
)

Expand Down
3 changes: 3 additions & 0 deletions python/fastexcel/_fastexcel.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class _ExcelReader:
column_names: list[str] | None = None,
skip_rows: int = 0,
n_rows: int | None = None,
schema_sample_rows: int | None = 1_000,
) -> _ExcelSheet: ...
def load_sheet_by_idx(
self,
Expand All @@ -41,6 +42,7 @@ class _ExcelReader:
column_names: list[str] | None = None,
skip_rows: int = 0,
n_rows: int | None = None,
schema_sample_rows: int | None = 1_000,
) -> _ExcelSheet: ...
def load_sheet(
self,
Expand All @@ -50,6 +52,7 @@ class _ExcelReader:
column_names: list[str] | None = None,
skip_rows: int = 0,
n_rows: int | None = None,
schema_sample_rows: int | None = 1_000,
) -> _ExcelSheet: ...
@property
def sheet_names(self) -> list[str]: ...
Expand Down
60 changes: 55 additions & 5 deletions python/tests/test_dtypes.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from __future__ import annotations

from datetime import datetime
from typing import Any

import fastexcel
import pandas as pd
import polars as pl
import pytest
from pandas.testing import assert_frame_equal as pd_assert_frame_equal
from polars.testing import assert_frame_equal as pl_assert_frame_equal
from utils import path_for_fixture


def test_sheet_with_mixed_dtypes() -> None:
excel_reader = fastexcel.read_excel(path_for_fixture("fixture-multi-dtypes-columns.xlsx"))
sheet = excel_reader.load_sheet(0)

expected_data = {
@pytest.fixture
def expected_data() -> dict[str, list[Any]]:
return {
"Employee ID": [
"123456",
"44333",
Expand Down Expand Up @@ -40,6 +42,54 @@ def test_sheet_with_mixed_dtypes() -> None:
"Asset ID": ["84444"] * 7 + ["ABC123"] * 2,
}


def test_sheet_with_mixed_dtypes(expected_data: dict[str, list[Any]]) -> None:
excel_reader = fastexcel.read_excel(path_for_fixture("fixture-multi-dtypes-columns.xlsx"))
sheet = excel_reader.load_sheet(0)

pd_df = sheet.to_pandas()
pd_assert_frame_equal(pd_df, pd.DataFrame(expected_data).astype({"Date": "datetime64[ms]"}))

pl_df = sheet.to_polars()
pl_assert_frame_equal(
pl_df, pl.DataFrame(expected_data, schema_overrides={"Date": pl.Datetime(time_unit="ms")})
)


def test_sheet_with_mixed_dtypes_and_sample_rows(expected_data: dict[str, list[Any]]) -> None:
excel_reader = fastexcel.read_excel(path_for_fixture("fixture-multi-dtypes-columns.xlsx"))

# Since we skip rows here, the dtypes should be correctly guessed, even if we only check 5 rows
sheet = excel_reader.load_sheet(0, schema_sample_rows=5, skip_rows=5)

expected_data_subset = {col_name: values[5:] for col_name, values in expected_data.items()}
pd_df = sheet.to_pandas()
pd_assert_frame_equal(
pd_df, pd.DataFrame(expected_data_subset).astype({"Date": "datetime64[ms]"})
)

pl_df = sheet.to_polars()
pl_assert_frame_equal(
pl_df,
pl.DataFrame(expected_data_subset, schema_overrides={"Date": pl.Datetime(time_unit="ms")}),
)

# Guess the sheet's dtypes on 5 rows only
sheet = excel_reader.load_sheet(0, schema_sample_rows=5)
# String fields should not have been loaded
expected_data["Employee ID"] = [
123456.0,
44333.0,
44333.0,
87878.0,
87878.0,
None,
135967.0,
None,
None,
]
expected_data["Asset ID"] = [84444.0] * 7 + [None] * 2

pd_df = sheet.to_pandas()
pd_assert_frame_equal(pd_df, pd.DataFrame(expected_data).astype({"Date": "datetime64[ms]"}))

Expand Down
26 changes: 21 additions & 5 deletions src/types/excelreader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ impl ExcelReader {
header_row = 0,
column_names = None,
skip_rows = 0,
n_rows = None
n_rows = None,
schema_sample_rows = 1_000,
))]
pub fn load_sheet_by_name(
&mut self,
Expand All @@ -53,6 +54,7 @@ impl ExcelReader {
column_names: Option<Vec<String>>,
skip_rows: usize,
n_rows: Option<usize>,
schema_sample_rows: Option<usize>,
) -> Result<ExcelSheet> {
let range = self
.sheets
Expand All @@ -61,7 +63,13 @@ impl ExcelReader {

let header = Header::new(header_row, column_names);
let pagination = Pagination::new(skip_rows, n_rows, &range)?;
Ok(ExcelSheet::new(name, range, header, pagination))
Ok(ExcelSheet::new(
name,
range,
header,
pagination,
schema_sample_rows,
))
}

#[pyo3(signature = (
Expand All @@ -70,15 +78,17 @@ impl ExcelReader {
header_row = 0,
column_names = None,
skip_rows = 0,
n_rows = None)
)]
n_rows = None,
schema_sample_rows = 1_000,
))]
pub fn load_sheet_by_idx(
&mut self,
idx: usize,
header_row: Option<usize>,
column_names: Option<Vec<String>>,
skip_rows: usize,
n_rows: Option<usize>,
schema_sample_rows: Option<usize>,
) -> Result<ExcelSheet> {
let name = self
.sheet_names
Expand All @@ -98,6 +108,12 @@ impl ExcelReader {

let header = Header::new(header_row, column_names);
let pagination = Pagination::new(skip_rows, n_rows, &range)?;
Ok(ExcelSheet::new(name, range, header, pagination))
Ok(ExcelSheet::new(
name,
range,
header,
pagination,
schema_sample_rows,
))
}
}
14 changes: 13 additions & 1 deletion src/types/excelsheet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ pub(crate) struct ExcelSheet {
height: Option<usize>,
total_height: Option<usize>,
width: Option<usize>,
schema_sample_rows: Option<usize>,
}

impl ExcelSheet {
Expand All @@ -88,12 +89,14 @@ impl ExcelSheet {
data: Range<CalData>,
header: Header,
pagination: Pagination,
schema_sample_rows: Option<usize>,
) -> Self {
ExcelSheet {
name,
header,
pagination,
data,
schema_sample_rows,
height: None,
total_height: None,
width: None,
Expand Down Expand Up @@ -138,6 +141,10 @@ impl ExcelSheet {

upper_bound
}

pub(crate) fn schema_sample_rows(&self) -> &Option<usize> {
&self.schema_sample_rows
}
}

fn create_boolean_array(
Expand Down Expand Up @@ -240,11 +247,16 @@ impl TryFrom<&ExcelSheet> for Schema {
type Error = anyhow::Error;

fn try_from(value: &ExcelSheet) -> Result<Self, Self::Error> {
// Checking how many rows we want to use to determine the dtype for a column. If sample_rows is
// not provided, we sample limit rows, i.e on the entire column
let sample_rows = value.offset() + value.schema_sample_rows().unwrap_or(value.limit());

arrow_schema_from_column_names_and_range(
value.data(),
&value.column_names(),
value.offset(),
value.limit(),
// If sample_rows is higher than the sheet's limit, use the limit instead
std::cmp::min(sample_rows, value.limit()),
)
}
}
Expand Down

0 comments on commit e4a69bc

Please sign in to comment.