diff --git a/python/fastexcel/__init__.py b/python/fastexcel/__init__.py
index fc65be3..e4becde 100644
--- a/python/fastexcel/__init__.py
+++ b/python/fastexcel/__init__.py
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
+
+from typing_extensions import TypeAlias
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -27,6 +29,9 @@
 )
 from ._fastexcel import read_excel as _read_excel
 
+DType = Literal["null", "int", "float", "string", "boolean", "datetime", "date", "duration"]
+DTypeMap: TypeAlias = "dict[str, DType] | dict[int, DType]"
+
 
 class ExcelSheet:
     """A class representing a single sheet in an Excel File"""
@@ -64,6 +69,11 @@ def available_columns(self) -> list[str]:
         """The columns available for the given sheet"""
         return self._sheet.available_columns
 
+    @property
+    def dtypes(self) -> DTypeMap | None:
+        """The dtypes specified for the sheet"""
+        return self._sheet.dtypes
+
     def to_arrow(self) -> pa.RecordBatch:
         """Converts the sheet to a pyarrow `RecordBatch`"""
         return self._sheet.to_arrow()
@@ -112,6 +122,7 @@ def load_sheet_by_name(
         n_rows: int | None = None,
         schema_sample_rows: int | None = 1_000,
         use_columns: list[str] | list[int] | str | None = None,
+        dtypes: DTypeMap | None = None,
     ) -> ExcelSheet:
         """Loads a sheet by name.
 
@@ -135,6 +146,8 @@ def load_sheet_by_name(
                             - a string, a comma separated list of Excel column letters and column
                               ranges (e.g. `“A:E”` or `“A,C,E:F”`, which would result in
                               `A,B,C,D,E` and `A,C,E,F`)
+        :param dtypes: An optional dict of dtypes. Keys can either be indices (in case `use_columns`
+                       is a list of ints or an Excel range), or column names
         """
         return ExcelSheet(
             self._reader.load_sheet_by_name(
@@ -145,6 +158,7 @@ def load_sheet_by_name(
                 n_rows=n_rows,
                 schema_sample_rows=schema_sample_rows,
                 use_columns=use_columns,
+                dtypes=dtypes,
             )
         )
 
@@ -158,6 +172,7 @@ def load_sheet_by_idx(
         n_rows: int | None = None,
         schema_sample_rows: int | None = 1_000,
         use_columns: list[str] | list[int] | str | None = None,
+        dtypes: DTypeMap | None = None,
     ) -> ExcelSheet:
         """Loads a sheet by index.
 
@@ -193,6 +208,7 @@ def load_sheet_by_idx(
                 n_rows=n_rows,
                 schema_sample_rows=schema_sample_rows,
                 use_columns=use_columns,
+                dtypes=dtypes,
             )
         )
 
@@ -206,6 +222,7 @@ def load_sheet(
         n_rows: int | None = None,
         schema_sample_rows: int | None = 1_000,
         use_columns: list[str] | list[int] | str | None = None,
+        dtypes: DTypeMap | None = None,
     ) -> ExcelSheet:
         """Loads a sheet by name if a string is passed or by index if an integer is passed.
 
@@ -220,6 +237,7 @@ def load_sheet(
                 n_rows=n_rows,
                 schema_sample_rows=schema_sample_rows,
                 use_columns=use_columns,
+                dtypes=dtypes,
             )
             if isinstance(idx_or_name, int)
             else self.load_sheet_by_name(
@@ -230,6 +248,7 @@ def load_sheet(
                 n_rows=n_rows,
                 schema_sample_rows=schema_sample_rows,
                 use_columns=use_columns,
+                dtypes=dtypes,
             )
         )
 
diff --git a/python/fastexcel/_fastexcel.pyi b/python/fastexcel/_fastexcel.pyi
index 0e27e59..cd614cc 100644
--- a/python/fastexcel/_fastexcel.pyi
+++ b/python/fastexcel/_fastexcel.pyi
@@ -1,7 +1,13 @@
 from __future__ import annotations
 
+from typing import Literal
+
 import pyarrow as pa
 
+_DType = Literal["null", "int", "float", "string", "boolean", "datetime", "date", "duration"]
+
+_DTypeMap = dict[str, _DType] | dict[int, _DType]
+
 class _ExcelSheet:
     @property
     def name(self) -> str:
@@ -24,6 +30,9 @@ class _ExcelSheet:
     @property
     def available_columns(self) -> list[str]:
         """The columns available for the given sheet"""
+    @property
+    def dtypes(self) -> _DTypeMap | None:
+        """The dtypes specified for the sheet"""
     def to_arrow(self) -> pa.RecordBatch:
         """Converts the sheet to a pyarrow `RecordBatch`"""
 
@@ -40,6 +49,7 @@ class _ExcelReader:
         n_rows: int | None = None,
         schema_sample_rows: int | None = 1_000,
         use_columns: list[str] | list[int] | str | None = None,
+        dtypes: _DTypeMap | None = None,
     ) -> _ExcelSheet: ...
     def load_sheet_by_idx(
         self,
@@ -51,6 +61,7 @@ class _ExcelReader:
         n_rows: int | None = None,
         schema_sample_rows: int | None = 1_000,
         use_columns: list[str] | list[int] | str | None = None,
+        dtypes: _DTypeMap | None = None,
     ) -> _ExcelSheet: ...
     @property
     def sheet_names(self) -> list[str]: ...
diff --git a/python/tests/test_dtypes.py b/python/tests/test_dtypes.py
index 6798de7..cc63f4e 100644
--- a/python/tests/test_dtypes.py
+++ b/python/tests/test_dtypes.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from datetime import datetime
+from datetime import date, datetime
 from typing import Any
 
 import fastexcel
@@ -97,3 +97,74 @@ def test_sheet_with_mixed_dtypes_and_sample_rows(expected_data: dict[str, list[A
     pl_assert_frame_equal(
         pl_df, pl.DataFrame(expected_data, schema_overrides={"Date": pl.Datetime(time_unit="ms")})
     )
+
+
+@pytest.mark.parametrize(
+    "dtype,expected_data,expected_pd_dtype,expected_pl_dtype",
+    [
+        ("int", [123456, 44333, 44333, 87878, 87878], "int64", pl.Int64),
+        ("float", [123456.0, 44333.0, 44333.0, 87878.0, 87878.0], "float64", pl.Float64),
+        ("string", ["123456", "44333", "44333", "87878", "87878"], "object", pl.Utf8),
+        ("boolean", [True] * 5, "bool", pl.Boolean),
+        (
+            "datetime",
+            [datetime(2238, 1, 3)] + [datetime(2021, 5, 17)] * 2 + [datetime(2140, 8, 6)] * 2,
+            "datetime64[ms]",
+            pl.Datetime,
+        ),
+        (
+            "date",
+            [date(2238, 1, 3)] + [date(2021, 5, 17)] * 2 + [date(2140, 8, 6)] * 2,
+            "object",
+            pl.Date,
+        ),
+        # conversion to duration not supported yet
+        ("duration", [pd.NaT] * 5, "timedelta64[ms]", pl.Duration),
+    ],
+)
+def test_sheet_with_mixed_dtypes_specify_dtypes(
+    dtype: fastexcel.DType,
+    expected_data: list[Any],
+    expected_pd_dtype: str,
+    expected_pl_dtype: pl.DataType,
+) -> None:
+    """A dtype forced on the "Employee ID" column is applied in both pandas and polars"""
+    excel_reader = fastexcel.read_excel(path_for_fixture("fixture-multi-dtypes-columns.xlsx"))
+    sheet = excel_reader.load_sheet(0, dtypes={"Employee ID": dtype}, n_rows=5)
+    assert sheet.dtypes == {"Employee ID": dtype}
+
+    pd_df = sheet.to_pandas()
+    assert pd_df["Employee ID"].dtype == expected_pd_dtype
+    assert pd_df["Employee ID"].to_list() == expected_data
+
+    pl_df = sheet.to_polars()
+    assert pl_df["Employee ID"].dtype == expected_pl_dtype
+    assert pl_df["Employee ID"].to_list() == (expected_data if dtype != "duration" else [None] * 5)
+
+
+@pytest.mark.parametrize(
+    "dtypes,expected,expected_pd_dtype,expected_pl_dtype",
+    [
+        (None, datetime(2023, 7, 21), "datetime64[ms]", pl.Datetime),
+        ({"Date": "datetime"}, datetime(2023, 7, 21), "datetime64[ms]", pl.Datetime),
+        ({"Date": "date"}, date(2023, 7, 21), "object", pl.Date),
+        ({"Date": "string"}, "2023-07-21 00:00:00", "object", pl.Utf8),
+    ],
+)
+def test_sheet_datetime_conversion(
+    dtypes: fastexcel.DTypeMap | None,
+    expected: Any,
+    expected_pd_dtype: str,
+    expected_pl_dtype: pl.DataType,
+) -> None:
+    """The "Date" column converts to the requested dtype; `None` keeps the inferred datetime"""
+    excel_reader = fastexcel.read_excel(path_for_fixture("fixture-multi-dtypes-columns.xlsx"))
+
+    sheet = excel_reader.load_sheet(0, dtypes=dtypes)
+    assert sheet.dtypes == dtypes
+    pd_df = sheet.to_pandas()
+    assert pd_df["Date"].dtype == expected_pd_dtype
+    assert pd_df["Date"].to_list() == [expected] * 9
+    pl_df = sheet.to_polars()
+    assert pl_df["Date"].dtype == expected_pl_dtype
+    assert pl_df["Date"].to_list() == [expected] * 9
diff --git a/src/types/dtype.rs b/src/types/dtype.rs
index c27f01f..1484bfd 100644
--- a/src/types/dtype.rs
+++ b/src/types/dtype.rs
@@ -1,7 +1,10 @@
 use std::{collections::HashMap, str::FromStr};
 
 use arrow::datatypes::{DataType as ArrowDataType, TimeUnit};
-use pyo3::types::PyDict;
+use pyo3::{
+    types::{IntoPyDict, PyDict},
+    PyObject, Python, ToPyObject,
+};
 
 use crate::error::{FastExcelError, FastExcelErrorKind, FastExcelResult};
 
@@ -38,6 +41,23 @@ impl FromStr for DType {
     }
 }
 
+/// Exposes a `DType` to Python as its lowercase string name (e.g. `"int"`).
+impl ToPyObject for DType {
+    fn to_object(&self, py: Python<'_>) -> PyObject {
+        match self {
+            DType::Null => "null",
+            DType::Int => "int",
+            DType::Float => "float",
+            DType::String => "string",
+            DType::Bool => "boolean",
+            DType::DateTime => "datetime",
+            DType::Date => "date",
+            DType::Duration => "duration",
+        }
+        .to_object(py)
+    }
+}
+
 #[derive(Debug)]
 pub(crate) enum DTypeMap {
     ByIndex(HashMap),
@@ -125,3 +145,22 @@ impl From<&DType> for ArrowDataType {
         }
     }
 }
+
+/// Exposes a `DTypeMap` to Python as a `dict` keyed by column index or column name.
+impl ToPyObject for DTypeMap {
+    fn to_object(&self, py: Python<'_>) -> PyObject {
+        match self {
+            DTypeMap::ByIndex(idx_map) => idx_map
+                .iter()
+                .map(|(k, v)| (k, v.to_object(py)))
+                .into_py_dict(py)
+                .into(),
+
+            DTypeMap::ByName(name_map) => name_map
+                .iter()
+                .map(|(k, v)| (k, v.to_object(py)))
+                .into_py_dict(py)
+                .into(),
+        }
+    }
+}
diff --git a/src/types/excelsheet.rs b/src/types/excelsheet.rs
index 40a390a..6b25d69 100644
--- a/src/types/excelsheet.rs
+++ b/src/types/excelsheet.rs
@@ -23,7 +23,7 @@ use chrono::NaiveDate;
 use pyo3::{
     prelude::{pyclass, pymethods, PyObject, Python},
     types::{PyList, PyString},
-    PyAny, PyResult,
+    PyAny, PyResult, ToPyObject,
 };
 
 use crate::utils::arrow::arrow_schema_from_column_names_and_range;
@@ -439,7 +439,12 @@ fn create_boolean_array(
     limit: usize,
 ) -> Arc {
     Arc::new(BooleanArray::from_iter((offset..limit).map(|row| {
-        data.get((row, col)).and_then(|cell| cell.get_bool())
+        data.get((row, col)).and_then(|cell| match cell {
+            CalData::Bool(b) => Some(*b),
+            CalData::Int(i) => Some(*i != 0),
+            CalData::Float(f) => Some(*f != 0.0),
+            _ => None,
+        })
     })))
 }
 
@@ -450,7 +455,7 @@ fn create_int_array(
     limit: usize,
 ) -> Arc {
     Arc::new(Int64Array::from_iter(
-        (offset..limit).map(|row| data.get((row, col)).and_then(|cell| cell.get_int())),
+        (offset..limit).map(|row| data.get((row, col)).and_then(|cell| cell.as_i64())),
     ))
 }
 
@@ -479,6 +484,8 @@ fn create_string_array(
             CalData::String(s) => Some(s.to_string()),
             CalData::Float(s) => Some(s.to_string()),
             CalData::Int(s) => Some(s.to_string()),
+            CalData::DateTime(dt) => dt.as_datetime().map(|dt| dt.to_string()),
+            CalData::DateTimeIso(dt) => Some(dt.to_string()),
             _ => None,
         })
     })))
@@ -667,6 +674,12 @@ impl ExcelSheet {
         PyList::new(py, &self.available_columns)
     }
 
+    /// The dtypes specified for the sheet, or `None` when none were given
+    #[getter]
+    pub fn dtypes<'p>(&'p self, py: Python<'p>) -> Option<PyObject> {
+        self.dtypes.as_ref().map(|dtypes| dtypes.to_object(py))
+    }
+
     pub fn to_arrow(&self, py: Python<'_>) -> PyResult {
         RecordBatch::try_from(self)
             .with_context(|| format!("could not create RecordBatch from sheet \"{}\"", &self.name))
diff --git a/src/utils/arrow.rs b/src/utils/arrow.rs
index c25e8d7..d81258f 100644
--- a/src/utils/arrow.rs
+++ b/src/utils/arrow.rs
@@ -177,7 +177,6 @@ pub(crate) fn arrow_schema_from_column_names_and_range(
             SelectedColumns::All => Some(idx),
             _ => selected_columns.idx_for_column(column_names, name, idx),
         } {
-            // let col_type = get_arrow_column_type(range, row_idx, row_limit, col_idx)?;
             let col_type = arrow_type_for_column(col_idx, name)?;
             let aliased_name = alias_for_name(name, &existing_names);
             fields.push(Field::new(&aliased_name, col_type, true));