Skip to content

Commit 94e59d4

Browse files
committed
refactor: allow dtypes to be specified by name or by index in case all columns are selected
Signed-off-by: Luka Peschke <[email protected]>
1 parent 457e372 commit 94e59d4

File tree

2 files changed

+32
-7
lines changed

2 files changed

+32
-7
lines changed

python/tests/test_dtypes.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def test_sheet_with_mixed_dtypes_and_sample_rows(expected_data: dict[str, list[A
9999
)
100100

101101

102+
@pytest.mark.parametrize("dtype_by_index", (True, False))
102103
@pytest.mark.parametrize(
103104
"dtype,expected_data,expected_pd_dtype,expected_pl_dtype",
104105
[
@@ -123,14 +124,16 @@ def test_sheet_with_mixed_dtypes_and_sample_rows(expected_data: dict[str, list[A
123124
],
124125
)
125126
def test_sheet_with_mixed_dtypes_specify_dtypes(
127+
dtype_by_index: bool,
126128
dtype: fastexcel.DType,
127129
expected_data: list[Any],
128130
expected_pd_dtype: str,
129131
expected_pl_dtype: pl.DataType,
130132
) -> None:
133+
dtypes: fastexcel.DTypeMap = {0: dtype} if dtype_by_index else {"Employee ID": dtype} # type:ignore[dict-item]
131134
excel_reader = fastexcel.read_excel(path_for_fixture("fixture-multi-dtypes-columns.xlsx"))
132-
sheet = excel_reader.load_sheet(0, dtypes={"Employee ID": dtype}, n_rows=5)
133-
assert sheet.dtypes == {"Employee ID": dtype}
135+
sheet = excel_reader.load_sheet(0, dtypes=dtypes, n_rows=5)
136+
assert sheet.dtypes == dtypes
134137

135138
pd_df = sheet.to_pandas()
136139
assert pd_df["Employee ID"].dtype == expected_pd_dtype
@@ -144,10 +147,13 @@ def test_sheet_with_mixed_dtypes_specify_dtypes(
144147
@pytest.mark.parametrize(
145148
"dtypes,expected,expected_pd_dtype,expected_pl_dtype",
146149
[
147-
(None, datetime(2023, 7, 21), "datetime64", pl.Datetime),
148-
({"Date": "datetime"}, datetime(2023, 7, 21), "datetime64", pl.Datetime),
150+
(None, datetime(2023, 7, 21), "datetime64[ms]", pl.Datetime),
151+
({"Date": "datetime"}, datetime(2023, 7, 21), "datetime64[ms]", pl.Datetime),
149152
({"Date": "date"}, date(2023, 7, 21), "object", pl.Date),
150153
({"Date": "string"}, "2023-07-21 00:00:00", "object", pl.Utf8),
154+
({2: "datetime"}, datetime(2023, 7, 21), "datetime64[ms]", pl.Datetime),
155+
({2: "date"}, date(2023, 7, 21), "object", pl.Date),
156+
({2: "string"}, "2023-07-21 00:00:00", "object", pl.Utf8),
151157
],
152158
)
153159
def test_sheet_datetime_conversion(
@@ -161,6 +167,9 @@ def test_sheet_datetime_conversion(
161167
sheet = excel_reader.load_sheet(0, dtypes=dtypes)
162168
assert sheet.dtypes == dtypes
163169
pd_df = sheet.to_pandas()
170+
assert pd_df["Date"].dtype == expected_pd_dtype
164171
assert pd_df["Date"].to_list() == [expected] * 9
172+
165173
pl_df = sheet.to_polars()
174+
assert pl_df["Date"].dtype == expected_pl_dtype
166175
assert pl_df["Date"].to_list() == [expected] * 9

src/utils/arrow.rs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,15 +150,31 @@ pub(crate) fn arrow_schema_from_column_names_and_range(
150150
#[allow(clippy::type_complexity)]
151151
let arrow_type_for_column: Box<dyn Fn(usize, &String) -> FastExcelResult<ArrowDataType>> =
152152
match selected_columns {
153-
// In case all columns are selected, or selected by name, we look up the dtype for the column by name, and
154-
// fallback on get_arrow_column_type
155-
SelectedColumns::All | SelectedColumns::ByName(_) => Box::new(|col_idx, col_name| {
153+
// In case all columns are selected, we look up the dtype for the column by name,
154+
// fallback on a lookup by index, and finally on get_arrow_column_type
155+
SelectedColumns::All => Box::new(|col_idx, col_name| match dtypes {
156+
None => get_arrow_column_type(range, row_idx, row_limit, col_idx),
157+
Some(dts) => {
158+
if let Some(dtype_by_name) = dts.dtype_for_col_name(col_name) {
159+
Ok(dtype_by_name.into())
160+
} else if let Some(dtype_by_idx) = dts.dtype_for_col_idx(col_idx) {
161+
Ok(dtype_by_idx.into())
162+
} else {
163+
get_arrow_column_type(range, row_idx, row_limit, col_idx)
164+
}
165+
}
166+
}),
167+
// If columns are selected by name, look up the dtype by name and fallback on
168+
// get_arrow_column_type
169+
SelectedColumns::ByName(_) => Box::new(|col_idx, col_name| {
156170
dtypes
157171
.and_then(|dtypes| dtypes.dtype_for_col_name(col_name))
158172
.map(|dtype| Ok(dtype.into()))
159173
.unwrap_or_else(|| get_arrow_column_type(range, row_idx, row_limit, col_idx))
160174
}),
161175

176+
// If columns are selected by index, look up the dtype by name and fallback on
177+
// get_arrow_column_type
162178
SelectedColumns::ByIndex(_) => Box::new(|col_idx, _col_name| {
163179
dtypes
164180
.and_then(|dtypes| dtypes.dtype_for_col_idx(col_idx))

0 commit comments

Comments
 (0)