Skip to content

Commit

Permalink
feat: add an option to disable automatic type coercion (#248)
Browse files Browse the repository at this point in the history
* fix: make test.py executable

Signed-off-by: Luka Peschke <[email protected]>

* feat: add an option to disable automatic type coercion

closes #247

Signed-off-by: Luka Peschke <[email protected]>

---------

Signed-off-by: Luka Peschke <[email protected]>
  • Loading branch information
lukapeschke authored Jul 4, 2024
1 parent f962672 commit 12e8628
Show file tree
Hide file tree
Showing 9 changed files with 210 additions and 13 deletions.
13 changes: 13 additions & 0 deletions python/fastexcel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def load_sheet(
skip_rows: int = 0,
n_rows: int | None = None,
schema_sample_rows: int | None = 1_000,
dtype_coercion: Literal["coerce", "strict"] = "coerce",
use_columns: list[str] | list[int] | str | Callable[[ColumnInfo], bool] | None = None,
dtypes: DTypeMap | None = None,
) -> ExcelSheet:
Expand All @@ -146,6 +147,11 @@ def load_sheet(
:param schema_sample_rows: Specifies how many rows should be used to determine
the dtype of a column.
If `None`, all rows will be used.
:param dtype_coercion: Specifies how type coercion should behave. `coerce` (the default)
will try to coerce different dtypes in a column to the same one,
whereas `strict` will raise an error in case a column contains
several dtypes. Note that this only applies to columns whose dtype
is guessed, i.e. not specified via `dtypes`.
:param use_columns: Specifies the columns to use. Can either be:
- `None` to select all columns
- A list of strings and ints, the column names and/or indices
Expand All @@ -165,6 +171,7 @@ def load_sheet(
skip_rows=skip_rows,
n_rows=n_rows,
schema_sample_rows=schema_sample_rows,
dtype_coercion=dtype_coercion,
use_columns=use_columns,
dtypes=dtypes,
eager=False,
Expand All @@ -180,6 +187,7 @@ def load_sheet_eager(
skip_rows: int = 0,
n_rows: int | None = None,
schema_sample_rows: int | None = 1_000,
dtype_coercion: Literal["coerce", "strict"] = "coerce",
use_columns: list[str] | list[int] | str | None = None,
dtypes: DTypeMap | None = None,
) -> pa.RecordBatch:
Expand All @@ -197,6 +205,7 @@ def load_sheet_eager(
skip_rows=skip_rows,
n_rows=n_rows,
schema_sample_rows=schema_sample_rows,
dtype_coercion=dtype_coercion,
use_columns=use_columns,
dtypes=dtypes,
eager=True,
Expand All @@ -211,6 +220,7 @@ def load_sheet_by_name(
skip_rows: int = 0,
n_rows: int | None = None,
schema_sample_rows: int | None = 1_000,
dtype_coercion: Literal["coerce", "strict"] = "coerce",
use_columns: list[str] | list[int] | str | Callable[[ColumnInfo], bool] | None = None,
dtypes: DTypeMap | None = None,
) -> ExcelSheet:
Expand All @@ -225,6 +235,7 @@ def load_sheet_by_name(
skip_rows=skip_rows,
n_rows=n_rows,
schema_sample_rows=schema_sample_rows,
dtype_coercion=dtype_coercion,
use_columns=use_columns,
dtypes=dtypes,
)
Expand All @@ -238,6 +249,7 @@ def load_sheet_by_idx(
skip_rows: int = 0,
n_rows: int | None = None,
schema_sample_rows: int | None = 1_000,
dtype_coercion: Literal["coerce", "strict"] = "coerce",
use_columns: list[str] | list[int] | str | Callable[[ColumnInfo], bool] | None = None,
dtypes: DTypeMap | None = None,
) -> ExcelSheet:
Expand All @@ -252,6 +264,7 @@ def load_sheet_by_idx(
skip_rows=skip_rows,
n_rows=n_rows,
schema_sample_rows=schema_sample_rows,
dtype_coercion=dtype_coercion,
use_columns=use_columns,
dtypes=dtypes,
)
Expand Down
2 changes: 2 additions & 0 deletions python/fastexcel/_fastexcel.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class _ExcelReader:
skip_rows: int = 0,
n_rows: int | None = None,
schema_sample_rows: int | None = 1_000,
dtype_coercion: Literal["coerce", "strict"] = "coerce",
use_columns: list[str] | list[int] | str | Callable[[ColumnInfo], bool] | None = None,
dtypes: DTypeMap | None = None,
eager: Literal[False] = ...,
Expand All @@ -86,6 +87,7 @@ class _ExcelReader:
skip_rows: int = 0,
n_rows: int | None = None,
schema_sample_rows: int | None = 1_000,
dtype_coercion: Literal["coerce", "strict"] = "coerce",
use_columns: list[str] | list[int] | str | None = None,
dtypes: DTypeMap | None = None,
eager: Literal[True] = ...,
Expand Down
66 changes: 65 additions & 1 deletion python/tests/test_dtypes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

from datetime import date, datetime
from typing import Any
from typing import Any, Literal

import fastexcel
import numpy as np
import pandas as pd
import polars as pl
import pytest
Expand Down Expand Up @@ -190,3 +191,66 @@ def test_sheet_datetime_conversion(
pl_df = sheet.to_polars()
assert pl_df["Date"].dtype == expected_pl_dtype
assert pl_df["Date"].to_list() == [expected] * 9


@pytest.mark.parametrize("eager", [True, False])
@pytest.mark.parametrize("dtype_coercion", ["coerce", None])
def test_dtype_coercion_behavior__coerce(
dtype_coercion: Literal["coerce"] | None, eager: bool
) -> None:
excel_reader = fastexcel.read_excel(path_for_fixture("fixture-multi-dtypes-columns.xlsx"))

kwargs = {"dtype_coercion": dtype_coercion} if dtype_coercion else {}
sheet = (
excel_reader.load_sheet_eager(0, **kwargs) # type:ignore[arg-type]
if eager
else excel_reader.load_sheet(0, **kwargs).to_arrow() # type:ignore[arg-type]
)

pd_df = sheet.to_pandas()
assert pd_df["Mixed dates"].dtype == "object"
assert pd_df["Mixed dates"].to_list() == ["2023-07-21 00:00:00"] * 6 + ["July 23rd"] * 3

pl_df = pl.from_arrow(data=sheet)
assert isinstance(pl_df, pl.DataFrame)
assert pl_df["Mixed dates"].dtype == pl.Utf8
assert pl_df["Mixed dates"].to_list() == ["2023-07-21 00:00:00"] * 6 + ["July 23rd"] * 3


@pytest.mark.parametrize("eager", [True, False])
def test_dtype_coercion_behavior__strict_sampling_eveything(eager: bool) -> None:
excel_reader = fastexcel.read_excel(path_for_fixture("fixture-multi-dtypes-columns.xlsx"))

with pytest.raises(
fastexcel.UnsupportedColumnTypeCombinationError, match="type coercion is strict"
):
if eager:
excel_reader.load_sheet_eager(0, dtype_coercion="strict")
else:
excel_reader.load_sheet(0, dtype_coercion="strict").to_arrow()


@pytest.mark.parametrize("eager", [True, False])
def test_dtype_coercion_behavior__strict_sampling_limit(eager: bool) -> None:
excel_reader = fastexcel.read_excel(path_for_fixture("fixture-multi-dtypes-columns.xlsx"))

sheet = (
excel_reader.load_sheet_eager(0, dtype_coercion="strict", schema_sample_rows=5)
if eager
else excel_reader.load_sheet(0, dtype_coercion="strict", schema_sample_rows=5).to_arrow()
)

pd_df = sheet.to_pandas()
assert pd_df["Mixed dates"].dtype == "datetime64[ms]"
assert (
pd_df["Mixed dates"].to_list() == [pd.Timestamp("2023-07-21 00:00:00")] * 6 + [pd.NaT] * 3
)
assert pd_df["Asset ID"].dtype == "float64"
assert pd_df["Asset ID"].replace(np.nan, None).to_list() == [84444.0] * 7 + [None] * 2

pl_df = pl.from_arrow(data=sheet)
assert isinstance(pl_df, pl.DataFrame)
assert pl_df["Mixed dates"].dtype == pl.Datetime
assert pl_df["Mixed dates"].to_list() == [datetime(2023, 7, 21)] * 6 + [None] * 3
assert pl_df["Asset ID"].dtype == pl.Float64
assert pl_df["Asset ID"].to_list() == [84444.0] * 7 + [None] * 2
98 changes: 96 additions & 2 deletions src/types/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,41 @@ impl From<&DType> for ArrowDataType {
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)]
pub(crate) enum DTypeCoercion {
Coerce,
Strict,
}

impl FromStr for DTypeCoercion {
type Err = FastExcelError;

fn from_str(raw_dtype_coercion: &str) -> FastExcelResult<Self> {
match raw_dtype_coercion {
"coerce" => Ok(Self::Coerce),
"strict" => Ok(Self::Strict),
_ => Err(FastExcelErrorKind::InvalidParameters(format!(
"unsupported dtype_coercion: \"{raw_dtype_coercion}\""
))
.into()),
}
}
}

impl FromPyObject<'_> for DTypeCoercion {
fn extract_bound(py_dtype_coercion: &Bound<'_, PyAny>) -> PyResult<Self> {
if let Ok(dtype_coercion_pystr) = py_dtype_coercion.extract::<&PyString>() {
dtype_coercion_pystr.to_str()?.parse()
} else {
Err(FastExcelErrorKind::InvalidParameters(format!(
"{py_dtype_coercion:?} cannot be converted to str"
))
.into())
}
.into_pyresult()
}
}

/// All the possible string values that should be considered as NULL
const NULL_STRING_VALUES: [&str; 19] = [
"", "#N/A", "#N/A N/A", "#NA", "-1.#IND", "-1.#QNAN", "-NaN", "-nan", "1.#IND", "1.#QNAN",
Expand Down Expand Up @@ -203,6 +238,7 @@ pub(crate) fn get_dtype_for_column<DT: CellType + Debug + DataType>(
start_row: usize,
end_row: usize,
col: usize,
dtype_coercion: &DTypeCoercion,
) -> FastExcelResult<DType> {
let mut column_types = (start_row..end_row)
.map(|row| get_cell_dtype(data, row, col))
Expand All @@ -214,6 +250,14 @@ pub(crate) fn get_dtype_for_column<DT: CellType + Debug + DataType>(
if column_types.is_empty() {
// If no type apart from NULL was found, it's a NULL column
Ok(DType::Null)
} else if matches!(dtype_coercion, &DTypeCoercion::Strict) && column_types.len() != 1 {
// If dtype coercion is strict and we do not have a single dtype, it's an error
Err(
FastExcelErrorKind::UnsupportedColumnTypeCombination(format!(
"type coercion is strict and column contains {column_types:?}"
))
.into(),
)
} else if column_types.len() == 1 {
// If a single non-null type was found, return it
Ok(column_types.into_iter().next().unwrap())
Expand Down Expand Up @@ -288,15 +332,65 @@ mod tests {
#[case(7, 11, DType::Float)]
// int + bool
#[case(10, 12, DType::Int)]
fn get_arrow_column_type_multi_dtype_ok(
fn get_arrow_column_type_multi_dtype_ok_coerce(
range: Range<CalData>,
#[case] start_row: usize,
#[case] end_row: usize,
#[case] expected: DType,
) {
assert_eq!(
get_dtype_for_column(&range, start_row, end_row, 0).unwrap(),
get_dtype_for_column(&range, start_row, end_row, 0, &DTypeCoercion::Coerce).unwrap(),
expected
);
}

#[rstest]
// pure bool
#[case(0, 2, DType::Bool)]
// pure int
#[case(3, 4, DType::Int)]
// pure float
#[case(4, 5, DType::Float)]
// pure string
#[case(5, 6, DType::String)]
// empty + null + int
#[case(6, 9, DType::Int)]
fn get_arrow_column_type_multi_dtype_ok_strict(
range: Range<CalData>,
#[case] start_row: usize,
#[case] end_row: usize,
#[case] expected: DType,
) {
assert_eq!(
get_dtype_for_column(&range, start_row, end_row, 0, &DTypeCoercion::Strict).unwrap(),
expected
);
}

#[rstest]
// pure int + float
#[case(3, 5)]
// float + string
#[case(4, 6)]
// int + float + string
#[case(3, 6)]
// null + int + float + string + empty + null
#[case(2, 8)]
// int + float + null
#[case(7, 10)]
// int + float + bool + null
#[case(7, 11)]
// int + bool
#[case(10, 12)]
fn get_arrow_column_type_multi_dtype_ko_strict(
range: Range<CalData>,
#[case] start_row: usize,
#[case] end_row: usize,
) {
let result = get_dtype_for_column(&range, start_row, end_row, 0, &DTypeCoercion::Strict);
assert!(matches!(
result.unwrap_err().kind,
FastExcelErrorKind::UnsupportedColumnTypeCombination(_)
));
}
}
13 changes: 12 additions & 1 deletion src/types/python/excelreader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ use crate::{
error::{
py_errors::IntoPyResult, ErrorContext, FastExcelError, FastExcelErrorKind, FastExcelResult,
},
types::{dtype::DTypeMap, idx_or_name::IdxOrName},
types::{
dtype::{DTypeCoercion, DTypeMap},
idx_or_name::IdxOrName,
},
};

use crate::utils::schema::get_schema_sample_rows;
Expand Down Expand Up @@ -108,6 +111,7 @@ impl ExcelReader {
sample_rows: Option<usize>,
selected_columns: &SelectedColumns,
dtypes: Option<&DTypeMap>,
dtype_coercion: &DTypeCoercion,
) -> FastExcelResult<RecordBatch> {
let offset = header.offset() + pagination.offset();
let limit = {
Expand All @@ -129,6 +133,7 @@ impl ExcelReader {
offset,
sample_rows_limit,
dtypes,
dtype_coercion,
)?;

let fields = available_columns
Expand All @@ -150,6 +155,7 @@ impl ExcelReader {
skip_rows: usize,
n_rows: Option<usize>,
schema_sample_rows: Option<usize>,
dtype_coercion: DTypeCoercion,
use_columns: Option<&Bound<'_, PyAny>>,
dtypes: Option<DTypeMap>,
eager: bool,
Expand All @@ -167,6 +173,7 @@ impl ExcelReader {
schema_sample_rows,
&selected_columns,
dtypes.as_ref(),
&dtype_coercion,
)
.into_pyresult()
.and_then(|rb| rb.to_pyarrow(py))
Expand All @@ -179,6 +186,7 @@ impl ExcelReader {
header,
pagination,
schema_sample_rows,
dtype_coercion,
selected_columns,
dtypes,
)
Expand Down Expand Up @@ -224,6 +232,7 @@ impl ExcelReader {
skip_rows = 0,
n_rows = None,
schema_sample_rows = 1_000,
dtype_coercion = DTypeCoercion::Coerce,
use_columns = None,
dtypes = None,
eager = false,
Expand All @@ -237,6 +246,7 @@ impl ExcelReader {
skip_rows: usize,
n_rows: Option<usize>,
schema_sample_rows: Option<usize>,
dtype_coercion: DTypeCoercion,
use_columns: Option<&Bound<'_, PyAny>>,
dtypes: Option<DTypeMap>,
eager: bool,
Expand Down Expand Up @@ -278,6 +288,7 @@ impl ExcelReader {
skip_rows,
n_rows,
schema_sample_rows,
dtype_coercion,
use_columns,
dtypes,
eager,
Expand Down
Loading

0 comments on commit 12e8628

Please sign in to comment.