diff --git a/crates/polars-core/src/datatypes/any_value.rs b/crates/polars-core/src/datatypes/any_value.rs index c3f5f57e0c68..5721ee2db2a9 100644 --- a/crates/polars-core/src/datatypes/any_value.rs +++ b/crates/polars-core/src/datatypes/any_value.rs @@ -902,6 +902,8 @@ impl<'a> AnyValue<'a> { #[cfg(feature = "dtype-time")] Time(v) => Time(v), List(v) => List(v), + #[cfg(feature = "dtype-array")] + Array(s, size) => Array(s, size), String(v) => StringOwned(PlSmallStr::from_str(v)), StringOwned(v) => StringOwned(v), Binary(v) => BinaryOwned(v.to_vec()), diff --git a/crates/polars-core/src/scalar/mod.rs b/crates/polars-core/src/scalar/mod.rs index ac7a946ebebc..3220e3468999 100644 --- a/crates/polars-core/src/scalar/mod.rs +++ b/crates/polars-core/src/scalar/mod.rs @@ -1,11 +1,14 @@ pub mod reduce; use polars_utils::pl_str::PlSmallStr; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; use crate::datatypes::{AnyValue, DataType}; use crate::prelude::Series; -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Scalar { dtype: DataType, value: AnyValue<'static>, diff --git a/crates/polars-expr/src/expressions/literal.rs b/crates/polars-expr/src/expressions/literal.rs index ad2e73cd8f70..88303ff31697 100644 --- a/crates/polars-expr/src/expressions/literal.rs +++ b/crates/polars-expr/src/expressions/literal.rs @@ -95,6 +95,7 @@ impl PhysicalExpr for LiteralExpr { .into_time() .into_series(), Series(series) => series.deref().clone(), + OtherScalar(s) => s.clone().into_series(get_literal_name().clone()), lv @ (Int(_) | Float(_) | StrCat(_)) => polars_core::prelude::Series::from_any_values( get_literal_name().clone(), &[lv.to_any_value().unwrap()], diff --git a/crates/polars-plan/src/plans/lit.rs b/crates/polars-plan/src/plans/lit.rs index c44fc3fe8147..b48896ae26d4 100644 --- a/crates/polars-plan/src/plans/lit.rs +++ b/crates/polars-plan/src/plans/lit.rs @@ -62,6 +62,7 @@ pub enum LiteralValue { #[cfg(feature = "dtype-time")] Time(i64), Series(SpecialEq), + OtherScalar(Scalar), // Used for dynamic languages Float(f64), // Used for dynamic languages @@ -135,7 +136,7 @@ impl LiteralValue { DateTime(v, tu, tz) => AnyValue::Datetime(*v, *tu, tz), #[cfg(feature = "dtype-time")] Time(v) => AnyValue::Time(*v), - Series(s) => AnyValue::List(s.0.clone().into_series()), + Series(_) => return None, Int(v) => materialize_dyn_int(*v), Float(v) => AnyValue::Float64(*v), StrCat(v) => AnyValue::String(v), @@ -174,6 +175,7 @@ impl LiteralValue { } }, Binary(v) => AnyValue::Binary(v), + OtherScalar(s) => s.value().clone(), }; Some(av) } @@ -214,6 +216,7 @@ impl LiteralValue { LiteralValue::Int(v) => DataType::Unknown(UnknownKind::Int(*v)), LiteralValue::Float(_) => DataType::Unknown(UnknownKind::Float), LiteralValue::StrCat(_) => DataType::Unknown(UnknownKind::Str), + LiteralValue::OtherScalar(s) => s.dtype().clone(), } } @@ -469,6 +472,12 @@ impl Literal for LiteralValue { } } +impl Literal for Scalar { + fn lit(self) -> Expr { + Expr::Literal(LiteralValue::OtherScalar(self)) + } +} + /// Create a Literal Expression from `L`. A literal expression behaves like a column that contains a single distinct /// value. /// diff --git a/crates/polars-python/src/functions/lazy.rs b/crates/polars-python/src/functions/lazy.rs index 108aaf2121b1..2d39bcdbdc09 100644 --- a/crates/polars-python/src/functions/lazy.rs +++ b/crates/polars-python/src/functions/lazy.rs @@ -409,7 +409,7 @@ pub fn nth(n: i64) -> PyExpr { } #[pyfunction] -pub fn lit(value: &Bound<'_, PyAny>, allow_object: bool) -> PyResult { +pub fn lit(value: &Bound<'_, PyAny>, allow_object: bool, is_scalar: bool) -> PyResult { if value.is_instance_of::() { let val = value.extract::().unwrap(); Ok(dsl::lit(val).into()) @@ -425,7 +425,16 @@ pub fn lit(value: &Bound<'_, PyAny>, allow_object: bool) -> PyResult { } else if let Ok(pystr) = value.downcast::() { Ok(dsl::lit(pystr.to_string()).into()) } else if let Ok(series) = value.extract::() { - Ok(dsl::lit(series.series).into()) + let s = series.series; + if is_scalar { + let av = s + .get(0) + .map_err(|_| PyValueError::new_err("expected at least 1 value"))?; + let av = av.into_static().map_err(PyPolarsErr::from)?; + Ok(dsl::lit(Scalar::new(s.dtype().clone(), av)).into()) + } else { + Ok(dsl::lit(s).into()) + } } else if value.is_none() { Ok(dsl::lit(Null {}).into()) } else if let Ok(value) = value.downcast::() { diff --git a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs index ce95204056bd..f5ea2455a8b4 100644 --- a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs +++ b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs @@ -561,6 +561,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { }, Binary(_) => return Err(PyNotImplementedError::new_err("binary literal")), Range { .. } => return Err(PyNotImplementedError::new_err("range literal")), + OtherScalar { .. } => return Err(PyNotImplementedError::new_err("scalar literal")), Date(..) | DateTime(..) | Decimal(..) => Literal { value: Wrap(lit.to_any_value().unwrap()).to_object(py), dtype, diff --git a/py-polars/polars/functions/lit.py b/py-polars/polars/functions/lit.py index 8853963cbeed..e1a6222deb15 100644 --- a/py-polars/polars/functions/lit.py +++ b/py-polars/polars/functions/lit.py @@ -73,7 +73,7 @@ def lit( if isinstance(value, datetime): if dtype == Date: - return wrap_expr(plr.lit(value.date(), allow_object=False)) + return wrap_expr(plr.lit(value.date(), allow_object=False, is_scalar=True)) # parse time unit if dtype is not None and (tu := getattr(dtype, "time_unit", "us")) is not None: @@ -102,7 +102,9 @@ def lit( raise TypeError(msg) dt_utc = value.replace(tzinfo=timezone.utc) - expr = wrap_expr(plr.lit(dt_utc, allow_object=False)).cast(Datetime(time_unit)) + expr = wrap_expr(plr.lit(dt_utc, allow_object=False, is_scalar=True)).cast( + Datetime(time_unit) + ) if tz is not None: expr = expr.dt.replace_time_zone( tz, ambiguous="earliest" if value.fold == 0 else "latest" @@ -110,36 +112,42 @@ def lit( return expr elif isinstance(value, timedelta): - expr = wrap_expr(plr.lit(value, allow_object=False)) + expr = wrap_expr(plr.lit(value, allow_object=False, is_scalar=True)) if dtype is not None and (tu := getattr(dtype, "time_unit", None)) is not None: expr = expr.cast(Duration(tu)) return expr elif isinstance(value, time): - return wrap_expr(plr.lit(value, allow_object=False)) + return wrap_expr(plr.lit(value, allow_object=False, is_scalar=True)) elif isinstance(value, date): if dtype == Datetime: time_unit = getattr(dtype, "time_unit", "us") or "us" dt_utc = datetime(value.year, value.month, value.day) - expr = wrap_expr(plr.lit(dt_utc, allow_object=False)).cast( + expr = wrap_expr(plr.lit(dt_utc, allow_object=False, is_scalar=True)).cast( Datetime(time_unit) ) if (time_zone := getattr(dtype, "time_zone", None)) is not None: expr = expr.dt.replace_time_zone(str(time_zone)) return expr else: - return wrap_expr(plr.lit(value, allow_object=False)) + return wrap_expr(plr.lit(value, allow_object=False, is_scalar=True)) elif isinstance(value, pl.Series): value = value._s - return wrap_expr(plr.lit(value, allow_object)) + return wrap_expr(plr.lit(value, allow_object, is_scalar=False)) elif _check_for_numpy(value) and isinstance(value, np.ndarray): return lit(pl.Series("literal", value, dtype=dtype)) elif isinstance(value, (list, tuple)): - return lit(pl.Series("literal", [value], dtype=dtype)) + return wrap_expr( + plr.lit( + pl.Series("literal", [value], dtype=dtype)._s, + allow_object, + is_scalar=True, + ) + ) elif isinstance(value, enum.Enum): lit_value = value.value @@ -148,7 +156,7 @@ def lit( return lit(lit_value, dtype=dtype) if dtype: - return wrap_expr(plr.lit(value, allow_object)).cast(dtype) + return wrap_expr(plr.lit(value, allow_object, is_scalar=True)).cast(dtype) try: # numpy literals like np.float32(0) have item/dtype @@ -171,4 +179,4 @@ def lit( except AttributeError: item = value - return wrap_expr(plr.lit(item, allow_object)) + return wrap_expr(plr.lit(item, allow_object, is_scalar=True)) diff --git a/py-polars/tests/unit/expr/test_literal.py b/py-polars/tests/unit/expr/test_literal.py new file mode 100644 index 000000000000..31567e7bc116 --- /dev/null +++ b/py-polars/tests/unit/expr/test_literal.py @@ -0,0 +1,21 @@ +import polars as pl + + +def test_literal_scalar_list_18686() -> None: + df = pl.DataFrame({"column1": [1, 2], "column2": ["A", "B"]}) + out = df.with_columns(lit1=pl.lit([]).cast(pl.List(pl.String)), lit2=pl.lit([])) + + assert out.to_dict(as_series=False) == { + "column1": [1, 2], + "column2": ["A", "B"], + "lit1": [[], []], + "lit2": [[], []], + } + assert out.schema == pl.Schema( + [ + ("column1", pl.Int64), + ("column2", pl.String), + ("lit1", pl.List(pl.String)), + ("lit2", pl.List(pl.Null)), + ] + )