Skip to content

Commit

Permalink
fix docttest
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 10, 2024
1 parent 6c47e3e commit 3e1dad2
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 8 deletions.
13 changes: 10 additions & 3 deletions py-polars/polars/_utils/construction/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
from polars.polars import PyDataFrame

if TYPE_CHECKING:
from polars import DataFrame, Series
from polars import DataFrame, Expr, Series
from polars._typing import (
Orientation,
PolarsDataType,
Expand Down Expand Up @@ -1212,15 +1212,22 @@ def arrow_to_pydf(
if rechunk:
pydf = pydf.rechunk()

def broadcastable_s(s: Series, name: str) -> Expr:
if s.len() == 1:
return F.lit(s).first().alias(name)
return F.lit(s).alias(name)

reset_order = False
if len(dictionary_cols) > 0:
df = wrap_df(pydf)
df = df.with_columns([F.lit(s).alias(s.name) for s in dictionary_cols.values()])
df = df.with_columns(
[broadcastable_s(s, name) for s in dictionary_cols.values()]
)
reset_order = True

if len(struct_cols) > 0:
df = wrap_df(pydf)
df = df.with_columns([F.lit(s).alias(s.name) for s in struct_cols.values()])
df = df.with_columns([broadcastable_s(s, name) for s in struct_cols.values()])
reset_order = True

if reset_order:
Expand Down
3 changes: 1 addition & 2 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1723,15 +1723,14 @@ def mode(self) -> Expr:
... "b": [1, 1, 2, 2],
... }
... )
>>> df.select(pl.all().mode()) # doctest: +IGNORE_RESULT
>>> df.select(pl.all().mode().first()) # doctest: +IGNORE_RESULT
shape: (2, 2)
┌─────┬─────┐
│ a ┆ b │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 1 ┆ 1 │
│ 1 ┆ 2 │
└─────┴─────┘
"""
return self._from_pyexpr(self._pyexpr.mode())
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -1999,7 +1999,7 @@ def test_add_string() -> None:

def test_df_broadcast() -> None:
df = pl.DataFrame({"a": [1, 2, 3]}, schema_overrides={"a": pl.UInt8})
out = df.with_columns(pl.Series("s", [[1, 2]]))
out = df.with_columns(pl.lit(pl.Series("s", [[1, 2]])).first())
assert out.shape == (3, 2)
assert out.schema == {"a": pl.UInt8, "s": pl.List(pl.Int64)}
assert out.rows() == [(1, [1, 2]), (2, [1, 2]), (3, [1, 2])]
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/functions/test_lit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
def test_lit_list_input(input: list[Any]) -> None:
df = pl.DataFrame({"a": [1, 2]})
result = df.with_columns(pl.lit(input))
result = df.with_columns(pl.lit(input).first())
expected = pl.DataFrame({"a": [1, 2], "literal": [input, input]})
assert_frame_equal(result, expected)

Expand All @@ -41,7 +41,7 @@ def test_lit_list_input(input: list[Any]) -> None:
)
def test_lit_tuple_input(input: tuple[Any, ...]) -> None:
df = pl.DataFrame({"a": [1, 2]})
result = df.with_columns(pl.lit(input))
result = df.with_columns(pl.lit(input).first())

expected = pl.DataFrame({"a": [1, 2], "literal": [list(input), list(input)]})
assert_frame_equal(result, expected)
Expand Down

0 comments on commit 3e1dad2

Please sign in to comment.