Skip to content

Commit

Permalink
fix(python,rust): Align new_columns argument for scan_csv and `read…
Browse files Browse the repository at this point in the history
…_csv` (#11575)
  • Loading branch information
c-peters authored Oct 7, 2023
1 parent 01edba3 commit 90b4c7e
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 10 deletions.
11 changes: 7 additions & 4 deletions py-polars/polars/io/csv/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,8 +799,8 @@ def scan_csv(
Single byte end of line character
new_columns
Provide an explicit list of string column names to use (for example, when
scanning a headerless CSV file). Note that unlike ``read_csv`` it is considered
an error to provide fewer column names than there are columns in the file.
scanning a headerless CSV file). If the given list is shorter than the width of
the DataFrame, the remaining columns keep their original names.
raise_if_empty
When there is no data in the source, ``NoDataError`` is raised. If this parameter
is set to False, an empty LazyFrame (with no columns) is returned instead.
Expand Down Expand Up @@ -885,8 +885,11 @@ def scan_csv(
dtypes = dict(zip(new_columns, dtypes))

# wrap new column names as a callable
def with_column_names(_cols: list[str]) -> list[str]:
return new_columns # type: ignore[return-value]
def with_column_names(cols: list[str]) -> list[str]:
    """
    Return the column names to use in place of ``cols``.

    The leading names are taken from ``new_columns`` (a name from the
    enclosing scope); any entries of ``cols`` beyond ``len(new_columns)``
    are kept unchanged. If ``new_columns`` is at least as long as ``cols``,
    the result is just the contents of ``new_columns``.

    Parameters
    ----------
    cols
        The original column names.

    Returns
    -------
    list of str
        A new list; neither ``cols`` nor ``new_columns`` is modified.
    """
    # Copy into a list so that any sequence type (e.g. a tuple) can be
    # concatenated with the remaining original names. Slicing past the end of
    # `cols` yields an empty list, so no explicit length check is required.
    return list(new_columns) + cols[len(new_columns) :]

_check_arg_is_1byte("separator", separator, can_be_empty=False)
_check_arg_is_1byte("comment_char", comment_char, can_be_empty=False)
Expand Down
5 changes: 3 additions & 2 deletions py-polars/src/lazyframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,9 @@ impl PyLazyFrame {
let new_names = out
.extract::<Vec<String>>(py)
.expect("python function should return List[str]");
assert_eq!(new_names.len(), schema.len(), "The length of the new names list should be equal to the original column length");

polars_ensure!(new_names.len() == schema.len(),
ShapeMismatch: "The length of the new names list should be equal to or less than the original column length",
);
Ok(schema
.iter_dtypes()
.zip(new_names)
Expand Down
13 changes: 13 additions & 0 deletions py-polars/tests/unit/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -1329,6 +1329,19 @@ def test_csv_scan_categorical(tmp_path: Path) -> None:
assert result["x"].dtype == pl.Categorical


@pytest.mark.write_disk()
def test_csv_scan_new_columns_less_than_original_columns(tmp_path: Path) -> None:
    """Check that a short ``new_columns`` list renames only the leading columns."""
    tmp_path.mkdir(exist_ok=True)
    csv_path = tmp_path / "test_csv_scan_new_columns.csv"

    # Write a three-column file; only two replacement names are supplied below.
    pl.DataFrame({"x": ["A"], "y": ["A"], "z": "A"}).write_csv(csv_path)

    lazy = pl.scan_csv(csv_path, new_columns=["x_new", "y_new"])
    collected = lazy.collect()

    # The third column is expected to keep its name from the file header.
    assert collected.columns == ["x_new", "y_new", "z"]


def test_read_csv_chunked() -> None:
"""Check that row count is properly functioning."""
N = 10_000
Expand Down
16 changes: 12 additions & 4 deletions py-polars/tests/unit/io/test_lazy_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import pytest

import polars as pl
from polars.exceptions import PolarsPanicError
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
Expand Down Expand Up @@ -150,12 +149,21 @@ def test_scan_csv_schema_new_columns_dtypes(
== df1.select(["sugars", pl.col("calories").cast(pl.Int64)]).rows()
)

# expect same number of column names as there are columns in the file
with pytest.raises(PolarsPanicError, match="should be equal"):
# partially rename columns / overwrite dtypes
df4 = pl.scan_csv(
file_path,
dtypes=[pl.Utf8, pl.Utf8],
new_columns=["category", "calories"],
).collect()
assert df4.dtypes == [pl.Utf8, pl.Utf8, pl.Float64, pl.Int64]
assert df4.columns == ["category", "calories", "fats_g", "sugars_g"]

# cannot have len(new_columns) > len(actual columns)
with pytest.raises(pl.ShapeError):
pl.scan_csv(
file_path,
dtypes=[pl.Utf8, pl.Utf8],
new_columns=["category", "calories"],
new_columns=["category", "calories", "c3", "c4", "c5"],
).collect()

# cannot set both 'new_columns' and 'with_column_names'
Expand Down

0 comments on commit 90b4c7e

Please sign in to comment.