diff --git a/Cargo.lock b/Cargo.lock index 2509252..6c2cba7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -23,9 +23,9 @@ dependencies = [ [[package]] name = "aho-corasick" -version = "0.7.20" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac" +checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" dependencies = [ "memchr", ] @@ -352,6 +352,7 @@ dependencies = [ "calamine", "chrono", "pyo3", + "rstest", ] [[package]] @@ -375,6 +376,12 @@ dependencies = [ "wasi", ] +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + [[package]] name = "half" version = "2.1.0" @@ -522,9 +529,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.5.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" [[package]] name = "memoffset" @@ -658,9 +665,9 @@ checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5" [[package]] name = "proc-macro2" -version = "1.0.44" +version = "1.0.78" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7bd7356a8122b6c4a24a82b278680c73357984ca2fc79a0f9fa6dea7dced7c58" +checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" dependencies = [ "unicode-ident", ] @@ -712,7 +719,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn", + "syn 1.0.101", ] [[package]] @@ -723,7 +730,7 @@ checksum = "97daff08a4c48320587b5224cc98d609e3c27b6d437315bd40b605c98eeb5918" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.101", ] [[package]] @@ -738,9 +745,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.21" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbe448f377a7d6961e30f5955f9b8d106c3f5e449d493ee1b125c1d43c2b5179" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" dependencies = [ "proc-macro2", ] @@ -756,20 +763,26 @@ dependencies = [ [[package]] name = "regex" -version = "1.7.0" +version = "1.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e076559ef8e241f2ae3479e36f97bd5741c0330689e217ad51ce2c76808b868a" +checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.6.27", + "regex-automata", + "regex-syntax 0.8.2", ] [[package]] -name = "regex-syntax" -version = "0.6.27" +name = "regex-automata" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3f87b73ce11b1619a3c6332f45341e0047173771e8b8b73f87bfeefb7b56244" +checksum = "5bb987efffd3c6d0d8f5f89510bb458559eab11e4f869acb20bf845e016259cd" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax 0.8.2", +] [[package]] name = "regex-syntax" @@ -777,12 +790,66 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a5996294f19bd3aae0453a862ad728f60e6600695733dd5df01da90c54363a3c" +[[package]] +name = "regex-syntax" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" + +[[package]] +name = "relative-path" +version = "1.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e898588f33fdd5b9420719948f9f2a32c922a246964576f71ba7f24f80610fbc" + +[[package]] +name = "rstest" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97eeab2f3c0a199bc4be135c36c924b6590b88c377d416494288c14f2db30199" +dependencies = [ + "rstest_macros", + "rustc_version", +] + +[[package]] +name = "rstest_macros" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d428f8247852f894ee1be110b375111b586d4fa431f6c46e64ba5a0dcccbe605" +dependencies = [ + "cfg-if", + "glob", + "proc-macro2", + "quote", + "regex", + "relative-path", + "rustc_version", + "syn 2.0.48", + "unicode-ident", +] + +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + [[package]] name = "scopeguard" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "semver" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97ed7a9823b74f99c7742f5336af7be5ecd3eeafcb1507d1fa93347b1d589b0" + [[package]] name = "serde" version = "1.0.145" @@ -812,6 +879,17 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "syn" +version = "2.0.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "target-lexicon" version = "0.12.4" @@ -829,9 +907,9 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.4" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcc811dc4066ac62f84f11307873c4850cb653bfa9b1719cee2bd2204a4bc5dd" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "unindent" @@ -872,7 +950,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn", + "syn 1.0.101", "wasm-bindgen-shared", ] @@ -894,7 +972,7 @@ checksum = "07bc0c051dc5f23e307b13285f9d75df86bfdf816c5721e573dec1f9b8aa193c" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.101", "wasm-bindgen-backend", "wasm-bindgen-shared", ] diff --git a/Cargo.toml b/Cargo.toml index 723c455..630a385 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,3 +19,6 @@ version = "40.0.0" # There's a lot of stuff we don't want here, such as serde support default-features = false features = ["pyarrow"] + +[dev-dependencies] +rstest = { version = "0.18.2", default-features = false } diff --git a/Makefile b/Makefile index 2cab815..5bee4ed 100644 --- a/Makefile +++ b/Makefile @@ -7,8 +7,9 @@ format = ruff format python/ *.py mypy = mypy python/ *.py pytest = pytest -v ## Rust -clippy = cargo clippy -fmt = cargo fmt +clippy = cargo clippy +fmt = cargo fmt +cargo-test = cargo test ## Docs pdoc = pdoc -o docs python/fastexcel @@ -38,6 +39,7 @@ prod-install: ./prod_install.sh test: + $(cargo-test) $(pytest) doc: diff --git a/python/fastexcel/__init__.py b/python/fastexcel/__init__.py index 9f7c0ee..46265d3 100644 --- a/python/fastexcel/__init__.py +++ b/python/fastexcel/__init__.py @@ -88,17 +88,23 @@ def load_sheet_by_name( column_names: list[str] | None = None, skip_rows: int = 0, n_rows: int | None = None, + schema_sample_rows: int | None = 1_000, ) -> ExcelSheet: """Loads a sheet by name. :param name: The name of the sheet to load. :param header_row: The index of the row containing the column labels, default index is 0. If `None`, the sheet does not have any column labels. - :param column_names: Overrides headers found in the document. If `column_names` is used, - `header_row` will be ignored. - :param n_rows: Specifies how many rows should be loaded. If `None`, all rows are loaded - :param skip_rows: Specifies how many should be skipped after the header. If `header_row` is - `None`, it skips the number of rows from the sheet's start. + :param column_names: Overrides headers found in the document. + If `column_names` is used, `header_row` will be ignored. + :param n_rows: Specifies how many rows should be loaded. + If `None`, all rows are loaded + :param skip_rows: Specifies how many rows should be skipped after the header. + If `header_row` is `None`, it skips the number of rows from the + start of the sheet. + :param schema_sample_rows: Specifies how many rows should be used to determine + the dtype of a column. + If `None`, all rows will be used. """ return ExcelSheet( self._reader.load_sheet_by_name( @@ -107,6 +113,7 @@ def load_sheet_by_name( column_names=column_names, skip_rows=skip_rows, n_rows=n_rows, + schema_sample_rows=schema_sample_rows, ) ) @@ -118,17 +125,23 @@ def load_sheet_by_idx( column_names: list[str] | None = None, skip_rows: int = 0, n_rows: int | None = None, + schema_sample_rows: int | None = 1_000, ) -> ExcelSheet: """Loads a sheet by index. :param idx: The index (starting at 0) of the sheet to load. :param header_row: The index of the row containing the column labels, default index is 0. If `None`, the sheet does not have any column labels. - :param column_names: Overrides headers found in the document. If `column_names` is used, - `header_row` will be ignored. - :param n_rows: Specifies how many rows should be loaded. If `None`, all rows are loaded - :param skip_rows: Specifies how many should be skipped after the header. If `header_row` is - `None`, it skips the number of rows from the sheet's start. + :param column_names: Overrides headers found in the document. + If `column_names` is used, `header_row` will be ignored. + :param n_rows: Specifies how many rows should be loaded. + If `None`, all rows are loaded + :param skip_rows: Specifies how many rows should be skipped after the header. + If `header_row` is `None`, it skips the number of rows from the + start of the sheet. + :param schema_sample_rows: Specifies how many rows should be used to determine + the dtype of a column. + If `None`, all rows will be used. """ if idx < 0: raise ValueError(f"Expected idx to be > 0, got {idx}") @@ -139,6 +152,7 @@ def load_sheet_by_idx( column_names=column_names, skip_rows=skip_rows, n_rows=n_rows, + schema_sample_rows=schema_sample_rows, ) ) @@ -150,6 +164,7 @@ def load_sheet( column_names: list[str] | None = None, skip_rows: int = 0, n_rows: int | None = None, + schema_sample_rows: int | None = 1_000, ) -> ExcelSheet: """Loads a sheet by name if a string is passed or by index if an integer is passed. @@ -162,6 +177,7 @@ def load_sheet( column_names=column_names, skip_rows=skip_rows, n_rows=n_rows, + schema_sample_rows=schema_sample_rows, ) if isinstance(idx_or_name, int) else self.load_sheet_by_name( @@ -170,6 +186,7 @@ def load_sheet( column_names=column_names, skip_rows=skip_rows, n_rows=n_rows, + schema_sample_rows=schema_sample_rows, ) ) diff --git a/python/fastexcel/_fastexcel.pyi b/python/fastexcel/_fastexcel.pyi index b9fd5eb..26e1841 100644 --- a/python/fastexcel/_fastexcel.pyi +++ b/python/fastexcel/_fastexcel.pyi @@ -32,6 +32,7 @@ class _ExcelReader: column_names: list[str] | None = None, skip_rows: int = 0, n_rows: int | None = None, + schema_sample_rows: int | None = 1_000, ) -> _ExcelSheet: ... def load_sheet_by_idx( self, @@ -41,6 +42,7 @@ class _ExcelReader: column_names: list[str] | None = None, skip_rows: int = 0, n_rows: int | None = None, + schema_sample_rows: int | None = 1_000, ) -> _ExcelSheet: ... def load_sheet( self, @@ -50,6 +52,7 @@ class _ExcelReader: column_names: list[str] | None = None, skip_rows: int = 0, n_rows: int | None = None, + schema_sample_rows: int | None = 1_000, ) -> _ExcelSheet: ... @property def sheet_names(self) -> list[str]: ... diff --git a/python/tests/fixtures/fixture-multi-dtypes-columns.xlsx b/python/tests/fixtures/fixture-multi-dtypes-columns.xlsx new file mode 100644 index 0000000..3d80b51 Binary files /dev/null and b/python/tests/fixtures/fixture-multi-dtypes-columns.xlsx differ diff --git a/python/tests/test_dtypes.py b/python/tests/test_dtypes.py new file mode 100644 index 0000000..6798de7 --- /dev/null +++ b/python/tests/test_dtypes.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any + +import fastexcel +import pandas as pd +import polars as pl +import pytest +from pandas.testing import assert_frame_equal as pd_assert_frame_equal +from polars.testing import assert_frame_equal as pl_assert_frame_equal +from utils import path_for_fixture + + +@pytest.fixture +def expected_data() -> dict[str, list[Any]]: + return { + "Employee ID": [ + "123456", + "44333", + "44333", + "87878", + "87878", + "US00011", + "135967", + "IN86868", + "IN86868", + ], + "Employee Name": [ + "Test1", + "Test2", + "Test2", + "Test3", + "Test3", + "Test4", + "Test5", + "Test6", + "Test6", + ], + "Date": [datetime(2023, 7, 21)] * 9, + "Details": ["Healthcare"] * 7 + ["Something"] * 2, + "Asset ID": ["84444"] * 7 + ["ABC123"] * 2, + } + + +def test_sheet_with_mixed_dtypes(expected_data: dict[str, list[Any]]) -> None: + excel_reader = fastexcel.read_excel(path_for_fixture("fixture-multi-dtypes-columns.xlsx")) + sheet = excel_reader.load_sheet(0) + + pd_df = sheet.to_pandas() + pd_assert_frame_equal(pd_df, pd.DataFrame(expected_data).astype({"Date": "datetime64[ms]"})) + + pl_df = sheet.to_polars() + pl_assert_frame_equal( + pl_df, pl.DataFrame(expected_data, schema_overrides={"Date": pl.Datetime(time_unit="ms")}) + ) + + +def test_sheet_with_mixed_dtypes_and_sample_rows(expected_data: dict[str, list[Any]]) -> None: + excel_reader = fastexcel.read_excel(path_for_fixture("fixture-multi-dtypes-columns.xlsx")) + + # Since we skip rows here, the dtypes should be correctly guessed, even if we only check 5 rows + sheet = excel_reader.load_sheet(0, schema_sample_rows=5, skip_rows=5) + + expected_data_subset = {col_name: values[5:] for col_name, values in expected_data.items()} + pd_df = sheet.to_pandas() + pd_assert_frame_equal( + pd_df, pd.DataFrame(expected_data_subset).astype({"Date": "datetime64[ms]"}) + ) + + pl_df = sheet.to_polars() + pl_assert_frame_equal( + pl_df, + pl.DataFrame(expected_data_subset, schema_overrides={"Date": pl.Datetime(time_unit="ms")}), + ) + + # Guess the sheet's dtypes on 5 rows only + sheet = excel_reader.load_sheet(0, schema_sample_rows=5) + # String fields should not have been loaded + expected_data["Employee ID"] = [ + 123456.0, + 44333.0, + 44333.0, + 87878.0, + 87878.0, + None, + 135967.0, + None, + None, + ] + expected_data["Asset ID"] = [84444.0] * 7 + [None] * 2 + + pd_df = sheet.to_pandas() + pd_assert_frame_equal(pd_df, pd.DataFrame(expected_data).astype({"Date": "datetime64[ms]"})) + + pl_df = sheet.to_polars() + pl_assert_frame_equal( + pl_df, pl.DataFrame(expected_data, schema_overrides={"Date": pl.Datetime(time_unit="ms")}) + ) diff --git a/src/types/excelreader.rs b/src/types/excelreader.rs index 592b235..0c9aadd 100644 --- a/src/types/excelreader.rs +++ b/src/types/excelreader.rs @@ -44,7 +44,8 @@ impl ExcelReader { header_row = 0, column_names = None, skip_rows = 0, - n_rows = None + n_rows = None, + schema_sample_rows = 1_000, ))] pub fn load_sheet_by_name( &mut self, @@ -53,6 +54,7 @@ impl ExcelReader { column_names: Option>, skip_rows: usize, n_rows: Option, + schema_sample_rows: Option, ) -> Result { let range = self .sheets @@ -61,7 +63,13 @@ impl ExcelReader { let header = Header::new(header_row, column_names); let pagination = Pagination::new(skip_rows, n_rows, &range)?; - Ok(ExcelSheet::new(name, range, header, pagination)) + Ok(ExcelSheet::new( + name, + range, + header, + pagination, + schema_sample_rows, + )) } #[pyo3(signature = ( @@ -70,8 +78,9 @@ impl ExcelReader { header_row = 0, column_names = None, skip_rows = 0, - n_rows = None) - )] + n_rows = None, + schema_sample_rows = 1_000, + ))] pub fn load_sheet_by_idx( &mut self, idx: usize, @@ -79,6 +88,7 @@ impl ExcelReader { column_names: Option>, skip_rows: usize, n_rows: Option, + schema_sample_rows: Option, ) -> Result { let name = self .sheet_names @@ -98,6 +108,12 @@ impl ExcelReader { let header = Header::new(header_row, column_names); let pagination = Pagination::new(skip_rows, n_rows, &range)?; - Ok(ExcelSheet::new(name, range, header, pagination)) + Ok(ExcelSheet::new( + name, + range, + header, + pagination, + schema_sample_rows, + )) } } diff --git a/src/types/excelsheet.rs b/src/types/excelsheet.rs index c10897b..e7ec2eb 100644 --- a/src/types/excelsheet.rs +++ b/src/types/excelsheet.rs @@ -76,6 +76,7 @@ pub(crate) struct ExcelSheet { height: Option, total_height: Option, width: Option, + schema_sample_rows: Option, } impl ExcelSheet { @@ -88,12 +89,14 @@ impl ExcelSheet { data: Range, header: Header, pagination: Pagination, + schema_sample_rows: Option, ) -> Self { ExcelSheet { name, header, pagination, data, + schema_sample_rows, height: None, total_height: None, width: None, @@ -110,8 +113,7 @@ impl ExcelSheet { .map(|col_idx| { self.data .get((*row_idx, col_idx)) - .and_then(|data| data.get_string()) - .map(ToOwned::to_owned) + .and_then(|data| data.as_string()) .unwrap_or(format!("__UNNAMED__{col_idx}")) }) .collect(), @@ -139,6 +141,10 @@ impl ExcelSheet { upper_bound } + + pub(crate) fn schema_sample_rows(&self) -> &Option { + &self.schema_sample_rows + } } fn create_boolean_array( @@ -169,9 +175,9 @@ fn create_float_array( offset: usize, limit: usize, ) -> Arc { - Arc::new(Float64Array::from_iter((offset..limit).map(|row| { - data.get((row, col)).and_then(|cell| cell.get_float()) - }))) + Arc::new(Float64Array::from_iter( + (offset..limit).map(|row| data.get((row, col)).and_then(|cell| cell.as_f64())), + )) } fn create_string_array( @@ -181,7 +187,15 @@ fn create_string_array( limit: usize, ) -> Arc { Arc::new(StringArray::from_iter((offset..limit).map(|row| { - data.get((row, col)).and_then(|cell| cell.get_string()) + // NOTE: Not using cell.as_string() here because it matches the String variant last, which + // is slower for columns containing mostly/only strings (which we expect to meet more often than + // mixed dtype columns containing mostly numbers) + data.get((row, col)).and_then(|cell| match cell { + CalData::String(s) => Some(s.to_string()), + CalData::Float(s) => Some(s.to_string()), + CalData::Int(s) => Some(s.to_string()), + _ => None, + }) }))) } @@ -233,10 +247,16 @@ impl TryFrom<&ExcelSheet> for Schema { type Error = anyhow::Error; fn try_from(value: &ExcelSheet) -> Result { + // Checking how many rows we want to use to determine the dtype for a column. If sample_rows is + // not provided, we sample limit rows, i.e on the entire column + let sample_rows = value.offset() + value.schema_sample_rows().unwrap_or(value.limit()); + arrow_schema_from_column_names_and_range( value.data(), &value.column_names(), value.offset(), + // If sample_rows is higher than the sheet's limit, use the limit instead + std::cmp::min(sample_rows, value.limit()), ) } } diff --git a/src/utils/arrow.rs b/src/utils/arrow.rs index 66bbf9c..12b2df9 100644 --- a/src/utils/arrow.rs +++ b/src/utils/arrow.rs @@ -1,8 +1,10 @@ +use std::{collections::HashSet, sync::OnceLock}; + use anyhow::{anyhow, Context, Result}; use arrow::datatypes::{DataType as ArrowDataType, Field, Schema, TimeUnit}; use calamine::{Data as CalData, DataType, Range}; -fn get_arrow_column_type(data: &Range, row: usize, col: usize) -> Result { +fn get_cell_type(data: &Range, row: usize, col: usize) -> Result { let cell = data .get((row, col)) .with_context(|| format!("Could not retrieve data at ({row},{col})"))?; @@ -34,6 +36,70 @@ fn get_arrow_column_type(data: &Range, row: usize, col: usize) -> Resul } } +static FLOAT_TYPES_CELL: OnceLock> = OnceLock::new(); +static INT_TYPES_CELL: OnceLock> = OnceLock::new(); +static STRING_TYPES_CELL: OnceLock> = OnceLock::new(); + +fn float_types() -> &'static HashSet { + FLOAT_TYPES_CELL.get_or_init(|| { + HashSet::from([ + ArrowDataType::Int64, + ArrowDataType::Float64, + ArrowDataType::Boolean, + ]) + }) +} + +fn int_types() -> &'static HashSet { + INT_TYPES_CELL.get_or_init(|| HashSet::from([ArrowDataType::Int64, ArrowDataType::Boolean])) +} + +fn string_types() -> &'static HashSet { + STRING_TYPES_CELL.get_or_init(|| { + HashSet::from([ + ArrowDataType::Int64, + ArrowDataType::Float64, + ArrowDataType::Utf8, + ]) + }) +} + +fn get_arrow_column_type( + data: &Range, + start_row: usize, + end_row: usize, + col: usize, +) -> Result { + let mut column_types = (start_row..end_row) + .map(|row| get_cell_type(data, row, col)) + .collect::>>()?; + + // All columns are nullable anyway so we're not taking Null into account here + column_types.remove(&ArrowDataType::Null); + + if column_types.is_empty() { + // If no type apart from NULL was found, it's a NULL column + Ok(ArrowDataType::Null) + } else if column_types.len() == 1 { + // If a single non-null type was found, return it + Ok(column_types.into_iter().next().unwrap()) + } else if column_types.is_subset(int_types()) { + // If every cell in the column can be converted to an int, return int64 + Ok(ArrowDataType::Int64) + } else if column_types.is_subset(float_types()) { + // If every cell in the column can be converted to a float, return Float64 + Ok(ArrowDataType::Float64) + } else if column_types.is_subset(string_types()) { + // If every cell in the column can be converted to a string, return Utf8 + Ok(ArrowDataType::Utf8) + } else { + // NOTE: Not being too smart about multi-types columns for now + Err(anyhow!( + "could not figure out column type for following type combination: {column_types:?}" + )) + } +} + fn alias_for_name(name: &str, fields: &[Field]) -> String { fn rec(name: &str, fields: &[Field], depth: usize) -> String { let alias = if depth == 0 { @@ -54,13 +120,76 @@ pub(crate) fn arrow_schema_from_column_names_and_range( range: &Range, column_names: &[String], row_idx: usize, + row_limit: usize, ) -> Result { let mut fields = Vec::with_capacity(column_names.len()); for (col_idx, name) in column_names.iter().enumerate() { - let col_type = get_arrow_column_type(range, row_idx, col_idx)?; + let col_type = get_arrow_column_type(range, row_idx, row_limit, col_idx)?; fields.push(Field::new(&alias_for_name(name, &fields), col_type, true)); } Ok(Schema::new(fields)) } + +#[cfg(test)] +mod tests { + use calamine::Cell; + use rstest::{fixture, rstest}; + + use super::*; + + #[fixture] + fn range() -> Range { + Range::from_sparse(vec![ + // First column + Cell::new((0, 0), CalData::Bool(true)), + Cell::new((1, 0), CalData::Bool(false)), + Cell::new((2, 0), CalData::Int(42)), + Cell::new((3, 0), CalData::Float(13.37)), + Cell::new((4, 0), CalData::String("hello".to_string())), + Cell::new((5, 0), CalData::Empty), + Cell::new((6, 0), CalData::Int(12)), + Cell::new((7, 0), CalData::Float(12.21)), + Cell::new((8, 0), CalData::Bool(true)), + Cell::new((9, 0), CalData::Int(1337)), + ]) + } + + #[rstest] + // pure bool + #[case(0, 2, ArrowDataType::Boolean)] + // pure int + #[case(2, 3, ArrowDataType::Int64)] + // pure float + #[case(3, 4, ArrowDataType::Float64)] + // pure string + #[case(4, 5, ArrowDataType::Utf8)] + // pure int + float + #[case(2, 4, ArrowDataType::Float64)] + // float + string + #[case(3, 5, ArrowDataType::Utf8)] + // int + float + string + #[case(2, 5, ArrowDataType::Utf8)] + // int + float + string + empty + #[case(2, 6, ArrowDataType::Utf8)] + // int + null + #[case(5, 7, ArrowDataType::Int64)] + // int + float + null + #[case(5, 8, ArrowDataType::Float64)] + // int + float + bool + null + #[case(5, 9, ArrowDataType::Float64)] + // int + bool + #[case(8, 10, ArrowDataType::Int64)] + fn get_arrow_column_type_multi_dtype_ok( + range: Range, + #[case] start_row: usize, + #[case] end_row: usize, + #[case] expected: ArrowDataType, + ) { + assert_eq!( + get_arrow_column_type(&range, start_row, end_row, 0).unwrap(), + expected + ); + } +}