Skip to content

Commit

Permalink
refactor: allow dtypes to be specified by name or by index in case al…
Browse files Browse the repository at this point in the history
…l columns are selected

Signed-off-by: Luka Peschke <[email protected]>
  • Loading branch information
lukapeschke committed Mar 4, 2024
1 parent 457e372 commit 94e59d4
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 7 deletions.
17 changes: 13 additions & 4 deletions python/tests/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def test_sheet_with_mixed_dtypes_and_sample_rows(expected_data: dict[str, list[A
)


@pytest.mark.parametrize("dtype_by_index", (True, False))
@pytest.mark.parametrize(
"dtype,expected_data,expected_pd_dtype,expected_pl_dtype",
[
Expand All @@ -123,14 +124,16 @@ def test_sheet_with_mixed_dtypes_and_sample_rows(expected_data: dict[str, list[A
],
)
def test_sheet_with_mixed_dtypes_specify_dtypes(
dtype_by_index: bool,
dtype: fastexcel.DType,
expected_data: list[Any],
expected_pd_dtype: str,
expected_pl_dtype: pl.DataType,
) -> None:
dtypes: fastexcel.DTypeMap = {0: dtype} if dtype_by_index else {"Employee ID": dtype} # type:ignore[dict-item]
excel_reader = fastexcel.read_excel(path_for_fixture("fixture-multi-dtypes-columns.xlsx"))
sheet = excel_reader.load_sheet(0, dtypes={"Employee ID": dtype}, n_rows=5)
assert sheet.dtypes == {"Employee ID": dtype}
sheet = excel_reader.load_sheet(0, dtypes=dtypes, n_rows=5)
assert sheet.dtypes == dtypes

pd_df = sheet.to_pandas()
assert pd_df["Employee ID"].dtype == expected_pd_dtype
Expand All @@ -144,10 +147,13 @@ def test_sheet_with_mixed_dtypes_specify_dtypes(
@pytest.mark.parametrize(
"dtypes,expected,expected_pd_dtype,expected_pl_dtype",
[
(None, datetime(2023, 7, 21), "datetime64", pl.Datetime),
({"Date": "datetime"}, datetime(2023, 7, 21), "datetime64", pl.Datetime),
(None, datetime(2023, 7, 21), "datetime64[ms]", pl.Datetime),
({"Date": "datetime"}, datetime(2023, 7, 21), "datetime64[ms]", pl.Datetime),
({"Date": "date"}, date(2023, 7, 21), "object", pl.Date),
({"Date": "string"}, "2023-07-21 00:00:00", "object", pl.Utf8),
({2: "datetime"}, datetime(2023, 7, 21), "datetime64[ms]", pl.Datetime),
({2: "date"}, date(2023, 7, 21), "object", pl.Date),
({2: "string"}, "2023-07-21 00:00:00", "object", pl.Utf8),
],
)
def test_sheet_datetime_conversion(
Expand All @@ -161,6 +167,9 @@ def test_sheet_datetime_conversion(
sheet = excel_reader.load_sheet(0, dtypes=dtypes)
assert sheet.dtypes == dtypes
pd_df = sheet.to_pandas()
assert pd_df["Date"].dtype == expected_pd_dtype
assert pd_df["Date"].to_list() == [expected] * 9

pl_df = sheet.to_polars()
assert pl_df["Date"].dtype == expected_pl_dtype
assert pl_df["Date"].to_list() == [expected] * 9
22 changes: 19 additions & 3 deletions src/utils/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,31 @@ pub(crate) fn arrow_schema_from_column_names_and_range(
#[allow(clippy::type_complexity)]
let arrow_type_for_column: Box<dyn Fn(usize, &String) -> FastExcelResult<ArrowDataType>> =
match selected_columns {
// In case all columns are selected, or selected by name, we look up the dtype for the column by name, and
// fallback on get_arrow_column_type
SelectedColumns::All | SelectedColumns::ByName(_) => Box::new(|col_idx, col_name| {
// In case all columns are selected, we look up the dtype for the column by name,
// fallback on a lookup by index, and finally on get_arrow_column_type
SelectedColumns::All => Box::new(|col_idx, col_name| match dtypes {
None => get_arrow_column_type(range, row_idx, row_limit, col_idx),
Some(dts) => {
if let Some(dtype_by_name) = dts.dtype_for_col_name(col_name) {
Ok(dtype_by_name.into())
} else if let Some(dtype_by_idx) = dts.dtype_for_col_idx(col_idx) {
Ok(dtype_by_idx.into())
} else {
get_arrow_column_type(range, row_idx, row_limit, col_idx)
}
}
}),
// If columns are selected by name, look up the dtype by name and fallback on
// get_arrow_column_type
SelectedColumns::ByName(_) => Box::new(|col_idx, col_name| {
dtypes
.and_then(|dtypes| dtypes.dtype_for_col_name(col_name))
.map(|dtype| Ok(dtype.into()))
.unwrap_or_else(|| get_arrow_column_type(range, row_idx, row_limit, col_idx))
}),

// If columns are selected by index, look up the dtype by name and fallback on
// get_arrow_column_type
SelectedColumns::ByIndex(_) => Box::new(|col_idx, _col_name| {
dtypes
.and_then(|dtypes| dtypes.dtype_for_col_idx(col_idx))
Expand Down

0 comments on commit 94e59d4

Please sign in to comment.