Skip to content

Commit

Permalink
fix: Fix scalar literals (#18707)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Sep 12, 2024
1 parent b0145cc commit a66532d
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 14 deletions.
2 changes: 2 additions & 0 deletions crates/polars-core/src/datatypes/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,8 @@ impl<'a> AnyValue<'a> {
#[cfg(feature = "dtype-time")]
Time(v) => Time(v),
List(v) => List(v),
#[cfg(feature = "dtype-array")]
Array(s, size) => Array(s, size),
String(v) => StringOwned(PlSmallStr::from_str(v)),
StringOwned(v) => StringOwned(v),
Binary(v) => BinaryOwned(v.to_vec()),
Expand Down
5 changes: 4 additions & 1 deletion crates/polars-core/src/scalar/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
pub mod reduce;

use polars_utils::pl_str::PlSmallStr;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::datatypes::{AnyValue, DataType};
use crate::prelude::Series;

#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Scalar {
dtype: DataType,
value: AnyValue<'static>,
Expand Down
1 change: 1 addition & 0 deletions crates/polars-expr/src/expressions/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ impl PhysicalExpr for LiteralExpr {
.into_time()
.into_series(),
Series(series) => series.deref().clone(),
OtherScalar(s) => s.clone().into_series(get_literal_name().clone()),
lv @ (Int(_) | Float(_) | StrCat(_)) => polars_core::prelude::Series::from_any_values(
get_literal_name().clone(),
&[lv.to_any_value().unwrap()],
Expand Down
11 changes: 10 additions & 1 deletion crates/polars-plan/src/plans/lit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ pub enum LiteralValue {
#[cfg(feature = "dtype-time")]
Time(i64),
Series(SpecialEq<Series>),
OtherScalar(Scalar),
// Used for dynamic languages
Float(f64),
// Used for dynamic languages
Expand Down Expand Up @@ -135,7 +136,7 @@ impl LiteralValue {
DateTime(v, tu, tz) => AnyValue::Datetime(*v, *tu, tz),
#[cfg(feature = "dtype-time")]
Time(v) => AnyValue::Time(*v),
Series(s) => AnyValue::List(s.0.clone().into_series()),
Series(_) => return None,
Int(v) => materialize_dyn_int(*v),
Float(v) => AnyValue::Float64(*v),
StrCat(v) => AnyValue::String(v),
Expand Down Expand Up @@ -174,6 +175,7 @@ impl LiteralValue {
}
},
Binary(v) => AnyValue::Binary(v),
OtherScalar(s) => s.value().clone(),
};
Some(av)
}
Expand Down Expand Up @@ -214,6 +216,7 @@ impl LiteralValue {
LiteralValue::Int(v) => DataType::Unknown(UnknownKind::Int(*v)),
LiteralValue::Float(_) => DataType::Unknown(UnknownKind::Float),
LiteralValue::StrCat(_) => DataType::Unknown(UnknownKind::Str),
LiteralValue::OtherScalar(s) => s.dtype().clone(),
}
}

Expand Down Expand Up @@ -469,6 +472,12 @@ impl Literal for LiteralValue {
}
}

impl Literal for Scalar {
    /// Turn this scalar into a literal expression.
    ///
    /// The scalar is moved into the `LiteralValue::OtherScalar` variant, which is
    /// then wrapped in `Expr::Literal`.
    fn lit(self) -> Expr {
        let literal = LiteralValue::OtherScalar(self);
        Expr::Literal(literal)
    }
}

/// Create a Literal Expression from `L`. A literal expression behaves like a column that contains a single distinct
/// value.
///
Expand Down
13 changes: 11 additions & 2 deletions crates/polars-python/src/functions/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ pub fn nth(n: i64) -> PyExpr {
}

#[pyfunction]
pub fn lit(value: &Bound<'_, PyAny>, allow_object: bool) -> PyResult<PyExpr> {
pub fn lit(value: &Bound<'_, PyAny>, allow_object: bool, is_scalar: bool) -> PyResult<PyExpr> {
if value.is_instance_of::<PyBool>() {
let val = value.extract::<bool>().unwrap();
Ok(dsl::lit(val).into())
Expand All @@ -425,7 +425,16 @@ pub fn lit(value: &Bound<'_, PyAny>, allow_object: bool) -> PyResult<PyExpr> {
} else if let Ok(pystr) = value.downcast::<PyString>() {
Ok(dsl::lit(pystr.to_string()).into())
} else if let Ok(series) = value.extract::<PySeries>() {
Ok(dsl::lit(series.series).into())
let s = series.series;
if is_scalar {
let av = s
.get(0)
.map_err(|_| PyValueError::new_err("expected at least 1 value"))?;
let av = av.into_static().map_err(PyPolarsErr::from)?;
Ok(dsl::lit(Scalar::new(s.dtype().clone(), av)).into())
} else {
Ok(dsl::lit(s).into())
}
} else if value.is_none() {
Ok(dsl::lit(Null {}).into())
} else if let Ok(value) = value.downcast::<PyBytes>() {
Expand Down
1 change: 1 addition & 0 deletions crates/polars-python/src/lazyframe/visitor/expr_nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
},
Binary(_) => return Err(PyNotImplementedError::new_err("binary literal")),
Range { .. } => return Err(PyNotImplementedError::new_err("range literal")),
OtherScalar { .. } => return Err(PyNotImplementedError::new_err("scalar literal")),
Date(..) | DateTime(..) | Decimal(..) => Literal {
value: Wrap(lit.to_any_value().unwrap()).to_object(py),
dtype,
Expand Down
28 changes: 18 additions & 10 deletions py-polars/polars/functions/lit.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def lit(

if isinstance(value, datetime):
if dtype == Date:
return wrap_expr(plr.lit(value.date(), allow_object=False))
return wrap_expr(plr.lit(value.date(), allow_object=False, is_scalar=True))

# parse time unit
if dtype is not None and (tu := getattr(dtype, "time_unit", "us")) is not None:
Expand Down Expand Up @@ -102,44 +102,52 @@ def lit(
raise TypeError(msg)

dt_utc = value.replace(tzinfo=timezone.utc)
expr = wrap_expr(plr.lit(dt_utc, allow_object=False)).cast(Datetime(time_unit))
expr = wrap_expr(plr.lit(dt_utc, allow_object=False, is_scalar=True)).cast(
Datetime(time_unit)
)
if tz is not None:
expr = expr.dt.replace_time_zone(
tz, ambiguous="earliest" if value.fold == 0 else "latest"
)
return expr

elif isinstance(value, timedelta):
expr = wrap_expr(plr.lit(value, allow_object=False))
expr = wrap_expr(plr.lit(value, allow_object=False, is_scalar=True))
if dtype is not None and (tu := getattr(dtype, "time_unit", None)) is not None:
expr = expr.cast(Duration(tu))
return expr

elif isinstance(value, time):
return wrap_expr(plr.lit(value, allow_object=False))
return wrap_expr(plr.lit(value, allow_object=False, is_scalar=True))

elif isinstance(value, date):
if dtype == Datetime:
time_unit = getattr(dtype, "time_unit", "us") or "us"
dt_utc = datetime(value.year, value.month, value.day)
expr = wrap_expr(plr.lit(dt_utc, allow_object=False)).cast(
expr = wrap_expr(plr.lit(dt_utc, allow_object=False, is_scalar=True)).cast(
Datetime(time_unit)
)
if (time_zone := getattr(dtype, "time_zone", None)) is not None:
expr = expr.dt.replace_time_zone(str(time_zone))
return expr
else:
return wrap_expr(plr.lit(value, allow_object=False))
return wrap_expr(plr.lit(value, allow_object=False, is_scalar=True))

elif isinstance(value, pl.Series):
value = value._s
return wrap_expr(plr.lit(value, allow_object))
return wrap_expr(plr.lit(value, allow_object, is_scalar=False))

elif _check_for_numpy(value) and isinstance(value, np.ndarray):
return lit(pl.Series("literal", value, dtype=dtype))

elif isinstance(value, (list, tuple)):
return lit(pl.Series("literal", [value], dtype=dtype))
return wrap_expr(
plr.lit(
pl.Series("literal", [value], dtype=dtype)._s,
allow_object,
is_scalar=True,
)
)

elif isinstance(value, enum.Enum):
lit_value = value.value
Expand All @@ -148,7 +156,7 @@ def lit(
return lit(lit_value, dtype=dtype)

if dtype:
return wrap_expr(plr.lit(value, allow_object)).cast(dtype)
return wrap_expr(plr.lit(value, allow_object, is_scalar=True)).cast(dtype)

try:
# numpy literals like np.float32(0) have item/dtype
Expand All @@ -171,4 +179,4 @@ def lit(
except AttributeError:
item = value

return wrap_expr(plr.lit(item, allow_object))
return wrap_expr(plr.lit(item, allow_object, is_scalar=True))
21 changes: 21 additions & 0 deletions py-polars/tests/unit/expr/test_literal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import polars as pl


def test_literal_scalar_list_18686() -> None:
    """Check that an empty-list literal yields one empty list per row.

    Two literal columns are added to a two-row frame: one built from
    ``pl.lit([])`` cast to ``List(String)`` and one left uncast. Both must
    contain an empty list in every row, and the uncast column must have the
    ``List(Null)`` dtype.
    """
    frame = pl.DataFrame({"column1": [1, 2], "column2": ["A", "B"]})
    result = frame.with_columns(
        lit1=pl.lit([]).cast(pl.List(pl.String)),
        lit2=pl.lit([]),
    )

    # Values: the original columns are untouched and each literal column
    # holds an empty list for both rows.
    expected_rows = {
        "column1": [1, 2],
        "column2": ["A", "B"],
        "lit1": [[], []],
        "lit2": [[], []],
    }
    assert result.to_dict(as_series=False) == expected_rows

    # Dtypes: the cast literal is List(String), the plain one stays List(Null).
    expected_schema = pl.Schema(
        [
            ("column1", pl.Int64),
            ("column2", pl.String),
            ("lit1", pl.List(pl.String)),
            ("lit2", pl.List(pl.Null)),
        ]
    )
    assert result.schema == expected_schema

0 comments on commit a66532d

Please sign in to comment.