diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index c7c622f96540..aac8155aaed3 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -31,6 +31,7 @@ pub use ndjson::*; pub use parquet::*; use polars_core::prelude::*; use polars_io::RowIndex; +use polars_ops::frame::JoinCoalesce; pub use polars_plan::frame::{AllowedOptimizations, OptState}; use polars_plan::global::FETCH_ROWS; use smartstring::alias::String as SmartString; @@ -1124,7 +1125,7 @@ impl LazyFrame { other, [left_on.into()], [right_on.into()], - JoinArgs::new(JoinType::Outer { coalesce: false }), + JoinArgs::new(JoinType::Outer), ) } @@ -1195,6 +1196,7 @@ impl LazyFrame { .right_on(right_on) .how(args.how) .validate(args.validation) + .coalesce(args.coalesce) .join_nulls(args.join_nulls); if let Some(suffix) = args.suffix { @@ -1764,6 +1766,7 @@ pub struct JoinBuilder { force_parallel: bool, suffix: Option, validation: JoinValidation, + coalesce: JoinCoalesce, join_nulls: bool, } impl JoinBuilder { @@ -1780,6 +1783,7 @@ impl JoinBuilder { join_nulls: false, suffix: None, validation: Default::default(), + coalesce: Default::default(), } } @@ -1851,6 +1855,12 @@ impl JoinBuilder { self } + /// Whether to coalesce join columns. + pub fn coalesce(mut self, coalesce: JoinCoalesce) -> Self { + self.coalesce = coalesce; + self + } + /// Finish builder pub fn finish(self) -> LazyFrame { let mut opt_state = self.lf.opt_state; @@ -1865,6 +1875,7 @@ impl JoinBuilder { suffix: self.suffix, slice: None, join_nulls: self.join_nulls, + coalesce: self.coalesce, }; let lp = self diff --git a/crates/polars-lazy/src/tests/streaming.rs b/crates/polars-lazy/src/tests/streaming.rs index c320c162b3e2..1c51e480636d 100644 --- a/crates/polars-lazy/src/tests/streaming.rs +++ b/crates/polars-lazy/src/tests/streaming.rs @@ -1,3 +1,5 @@ +use polars_ops::frame::JoinCoalesce; + use super::*; fn get_csv_file() -> LazyFrame { @@ -295,7 +297,8 @@ fn test_streaming_partial() -> PolarsResult<()> { .left_on([col("a")]) .right_on([col("a")]) .suffix("_foo") - .how(JoinType::Outer { coalesce: true }) + .how(JoinType::Outer) + .coalesce(JoinCoalesce::CoalesceColumns) .finish(); let q = q.left_join( diff --git a/crates/polars-ops/src/frame/join/args.rs b/crates/polars-ops/src/frame/join/args.rs index 51bbbf9d80fe..148c46ce7953 100644 --- a/crates/polars-ops/src/frame/join/args.rs +++ b/crates/polars-ops/src/frame/join/args.rs @@ -26,6 +26,36 @@ pub struct JoinArgs { pub suffix: Option, pub slice: Option<(i64, usize)>, pub join_nulls: bool, + pub coalesce: JoinCoalesce, +} + +#[derive(Clone, PartialEq, Eq, Debug, Hash, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum JoinCoalesce { + #[default] + JoinSpecific, + CoalesceColumns, + KeepColumns, +} + +impl JoinCoalesce { + pub fn coalesce(&self, join_type: &JoinType) -> bool { + use JoinCoalesce::*; + use JoinType::*; + match join_type { + Left | Inner => { + matches!(self, JoinSpecific | CoalesceColumns) + }, + Outer { .. } => { + matches!(self, CoalesceColumns) + }, + #[cfg(feature = "asof_join")] + AsOf(_) => false, + Cross => false, + #[cfg(feature = "semi_anti_join")] + Semi | Anti => false, + } + } } impl Default for JoinArgs { @@ -36,6 +66,7 @@ impl Default for JoinArgs { suffix: None, slice: None, join_nulls: false, + coalesce: Default::default(), } } } @@ -48,9 +79,15 @@ impl JoinArgs { suffix: None, slice: None, join_nulls: false, + coalesce: Default::default(), } } + pub fn with_coalesce(mut self, coalesce: JoinCoalesce) -> Self { + self.coalesce = coalesce; + self + } + pub fn suffix(&self) -> &str { self.suffix.as_deref().unwrap_or("_right") } @@ -61,9 +98,7 @@ impl JoinArgs { pub enum JoinType { Left, Inner, - Outer { - coalesce: bool, - }, + Outer, #[cfg(feature = "asof_join")] AsOf(AsOfOptions), Cross, @@ -73,18 +108,6 @@ pub enum JoinType { Anti, } -impl JoinType { - pub fn merges_join_keys(&self) -> bool { - match self { - Self::Outer { coalesce } => *coalesce, - // Merges them if they are equal - #[cfg(feature = "asof_join")] - Self::AsOf(_) => false, - _ => true, - } - } -} - impl From for JoinArgs { fn from(value: JoinType) -> Self { JoinArgs::new(value) @@ -116,6 +139,19 @@ impl Debug for JoinType { } } +impl JoinType { + pub fn is_asof(&self) -> bool { + #[cfg(feature = "asof_join")] + { + matches!(self, JoinType::AsOf(_)) + } + #[cfg(not(feature = "asof_join"))] + { + false + } + } +} + #[derive(Copy, Clone, PartialEq, Eq, Default, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum JoinValidation { diff --git a/crates/polars-ops/src/frame/join/hash_join/mod.rs b/crates/polars-ops/src/frame/join/hash_join/mod.rs index 0be95b1aa1cf..f6b1ca773ee4 100644 --- a/crates/polars-ops/src/frame/join/hash_join/mod.rs +++ b/crates/polars-ops/src/frame/join/hash_join/mod.rs @@ -271,9 +271,7 @@ pub trait JoinDispatch: IntoDf { || unsafe { other.take_unchecked(&idx_ca_r) }, ); - let JoinType::Outer { coalesce } = args.how else { - unreachable!() - }; + let coalesce = args.coalesce.coalesce(&JoinType::Outer); let out = _finish_join(df_left, df_right, args.suffix.as_deref()); if coalesce { Ok(_coalesce_outer_join( diff --git a/crates/polars-ops/src/frame/join/mod.rs b/crates/polars-ops/src/frame/join/mod.rs index f3df643de0e8..6a29e2b28c3a 100644 --- a/crates/polars-ops/src/frame/join/mod.rs +++ b/crates/polars-ops/src/frame/join/mod.rs @@ -209,9 +209,7 @@ pub trait DataFrameJoinOps: IntoDf { JoinType::Left => { left_df._left_join_from_series(other, s_left, s_right, args, _verbose, None) }, - JoinType::Outer { .. } => { - left_df._outer_join_from_series(other, s_left, s_right, args) - }, + JoinType::Outer => left_df._outer_join_from_series(other, s_left, s_right, args), #[cfg(feature = "semi_anti_join")] JoinType::Anti => left_df._semi_anti_join_from_series( s_left, @@ -278,13 +276,14 @@ pub trait DataFrameJoinOps: IntoDf { JoinType::Cross => { unreachable!() }, - JoinType::Outer { coalesce } => { + JoinType::Outer => { let names_left = selected_left.iter().map(|s| s.name()).collect::>(); - args.how = JoinType::Outer { coalesce: false }; + let coalesce = args.coalesce; + args.coalesce = JoinCoalesce::KeepColumns; let suffix = args.suffix.clone(); let out = left_df._outer_join_from_series(other, &lhs_keys, &rhs_keys, args); - if coalesce { + if coalesce.coalesce(&JoinType::Outer) { Ok(_coalesce_outer_join( out?, &names_left, @@ -411,12 +410,7 @@ pub trait DataFrameJoinOps: IntoDf { I: IntoIterator, S: AsRef, { - self.join( - other, - left_on, - right_on, - JoinArgs::new(JoinType::Outer { coalesce: false }), - ) + self.join(other, left_on, right_on, JoinArgs::new(JoinType::Outer)) } } diff --git a/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs b/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs index 864020d1a8a1..1fa7ce58a152 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs @@ -5,6 +5,7 @@ use hashbrown::hash_map::RawEntryMut; use polars_core::export::ahash::RandomState; use polars_core::prelude::*; use polars_core::utils::{_set_partition_size, accumulate_dataframes_vertical_unchecked}; +use polars_ops::prelude::JoinArgs; use polars_utils::arena::Node; use polars_utils::slice::GetSaferUnchecked; use polars_utils::unitvec; @@ -34,6 +35,7 @@ pub struct GenericBuild { materialized_join_cols: Vec>, suffix: Arc, hb: RandomState, + join_args: JoinArgs, // partitioned tables that will be used for probing // stores the key and the chunk_idx, df_idx of the left table hash_tables: PartitionedMap, @@ -45,7 +47,6 @@ pub struct GenericBuild { // amortize allocations join_columns: Vec, hashes: Vec, - join_type: JoinType, // the join order is swapped to ensure we hash the smaller table swapped: bool, join_nulls: bool, @@ -59,7 +60,7 @@ impl GenericBuild { #[allow(clippy::too_many_arguments)] pub(crate) fn new( suffix: Arc, - join_type: JoinType, + join_args: JoinArgs, swapped: bool, join_columns_left: Arc>>, join_columns_right: Arc>>, @@ -76,7 +77,7 @@ impl GenericBuild { })); GenericBuild { chunks: vec![], - join_type, + join_args, suffix, hb, swapped, @@ -278,7 +279,7 @@ impl Sink for GenericBuild { fn split(&self, _thread_no: usize) -> Box { let mut new = Self::new( self.suffix.clone(), - self.join_type.clone(), + self.join_args.clone(), self.swapped, self.join_columns_left.clone(), self.join_columns_right.clone(), @@ -317,7 +318,7 @@ impl Sink for GenericBuild { let mut hashes = std::mem::take(&mut self.hashes); hashes.clear(); - match self.join_type { + match self.join_args.how { JoinType::Inner | JoinType::Left => { let probe_operator = GenericJoinProbe::new( left_df, @@ -330,13 +331,14 @@ impl Sink for GenericBuild { self.swapped, hashes, context, - self.join_type.clone(), + self.join_args.how.clone(), self.join_nulls, ); self.placeholder.replace(Box::new(probe_operator)); Ok(FinalizedSink::Operator) }, - JoinType::Outer { coalesce } => { + JoinType::Outer => { + let coalesce = self.join_args.coalesce.coalesce(&JoinType::Outer); let probe_operator = GenericOuterJoinProbe::new( left_df, materialized_join_cols, diff --git a/crates/polars-pipe/src/pipeline/convert.rs b/crates/polars-pipe/src/pipeline/convert.rs index 73afa86fea0a..a0e4aee37ee8 100644 --- a/crates/polars-pipe/src/pipeline/convert.rs +++ b/crates/polars-pipe/src/pipeline/convert.rs @@ -285,12 +285,12 @@ where }; match jt { - join_type @ JoinType::Inner | join_type @ JoinType::Left => { + JoinType::Inner | JoinType::Left => { let (join_columns_left, join_columns_right) = swap_eval(); Box::new(GenericBuild::<()>::new( Arc::from(options.args.suffix()), - join_type.clone(), + options.args.clone(), swapped, join_columns_left, join_columns_right, @@ -317,7 +317,7 @@ where Box::new(GenericBuild::::new( Arc::from(options.args.suffix()), - jt.clone(), + options.args.clone(), swapped, join_columns_left, join_columns_right, diff --git a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/joins.rs b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/joins.rs index fbdb528c6ed7..10e108d26008 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/joins.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/joins.rs @@ -258,7 +258,8 @@ pub(super) fn process_join( already_added_local_to_local_projected.insert(local_name); } // In outer joins both columns remain. So `add_local=true` also for the right table - let add_local = matches!(options.args.how, JoinType::Outer { coalesce: false }); + let add_local = matches!(options.args.how, JoinType::Outer) + && !options.args.coalesce.coalesce(&options.args.how); for e in &right_on { // In case of outer joins we also add the columns. // But before we do that we must check if the column wasn't already added by the lhs. diff --git a/crates/polars-plan/src/logical_plan/schema.rs b/crates/polars-plan/src/logical_plan/schema.rs index cc7a298eba13..7d7044e498e1 100644 --- a/crates/polars-plan/src/logical_plan/schema.rs +++ b/crates/polars-plan/src/logical_plan/schema.rs @@ -313,11 +313,11 @@ pub(crate) fn det_join_schema( new_schema.with_column(field.name, field.dtype); arena.clear(); } - // except in asof joins. Asof joins are not equi-joins + // Except in asof joins. Asof joins are not equi-joins // so the columns that are joined on, may have different // values so if the right has a different name, it is added to the schema #[cfg(feature = "asof_join")] - if !options.args.how.merges_join_keys() { + if !options.args.coalesce.coalesce(&options.args.how) { for (left_on, right_on) in left_on.iter().zip(right_on) { let field_left = left_on.to_field_amortized(schema_left, Context::Default, &mut arena)?; @@ -342,10 +342,13 @@ pub(crate) fn det_join_schema( join_on_right.insert(field.name); } + let are_coalesced = options.args.coalesce.coalesce(&options.args.how); + let is_asof = options.args.how.is_asof(); + + // Asof joins are special, if the names are equal they will not be coalesced. for (name, dtype) in schema_right.iter() { - if !join_on_right.contains(name.as_str()) // The names that are joined on are merged - || matches!(&options.args.how, JoinType::Outer{coalesce: false}) - // The names are not merged + if !join_on_right.contains(name.as_str()) || (!are_coalesced && !is_asof) + // The names that are joined on are merged { if schema_left.contains(name.as_str()) { #[cfg(feature = "asof_join")] diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index a066bd91fd13..6fc6ac559968 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -319,14 +319,9 @@ impl SQLContext { let (r_name, rf) = self.get_table(&tbl.relation)?; lf = match &tbl.join_operator { JoinOperator::CrossJoin => lf.cross_join(rf), - JoinOperator::FullOuter(constraint) => process_join( - lf, - rf, - constraint, - &l_name, - &r_name, - JoinType::Outer { coalesce: false }, - )?, + JoinOperator::FullOuter(constraint) => { + process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Outer)? + }, JoinOperator::Inner(constraint) => { process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Inner)? }, diff --git a/crates/polars/tests/it/core/joins.rs b/crates/polars/tests/it/core/joins.rs index 212de7960562..0542e77f96f1 100644 --- a/crates/polars/tests/it/core/joins.rs +++ b/crates/polars/tests/it/core/joins.rs @@ -119,7 +119,7 @@ fn test_outer_join() -> PolarsResult<()> { &rain, ["days"], ["days"], - JoinArgs::new(JoinType::Outer { coalesce: true }), + JoinArgs::new(JoinType::Outer).with_coalesce(JoinCoalesce::CoalesceColumns), )?; assert_eq!(joined.height(), 5); assert_eq!(joined.column("days")?.sum::().unwrap(), 7); @@ -139,7 +139,7 @@ fn test_outer_join() -> PolarsResult<()> { &df_right, ["a"], ["a"], - JoinArgs::new(JoinType::Outer { coalesce: true }), + JoinArgs::new(JoinType::Outer).with_coalesce(JoinCoalesce::CoalesceColumns), )?; assert_eq!(out.column("c_right")?.null_count(), 1); @@ -254,7 +254,7 @@ fn test_join_multiple_columns() { &df_b, ["a", "b"], ["foo", "bar"], - JoinType::Outer { coalesce: true }.into(), + JoinArgs::new(JoinType::Outer).with_coalesce(JoinCoalesce::CoalesceColumns), ) .unwrap(); assert!(joined_outer_hack @@ -300,11 +300,7 @@ fn test_join_categorical() { assert_eq!(Vec::from(ca), correct_ham); // test dispatch - for jt in [ - JoinType::Left, - JoinType::Inner, - JoinType::Outer { coalesce: true }, - ] { + for jt in [JoinType::Left, JoinType::Inner, JoinType::Outer] { let out = df_a.join(&df_b, ["b"], ["bar"], jt.into()).unwrap(); let out = out.column("b").unwrap(); assert_eq!( @@ -471,7 +467,7 @@ fn test_joins_with_duplicates() -> PolarsResult<()> { &df_right, ["col1"], ["join_col1"], - JoinArgs::new(JoinType::Outer { coalesce: true }), + JoinArgs::new(JoinType::Outer).with_coalesce(JoinCoalesce::CoalesceColumns), ) .unwrap(); @@ -543,7 +539,7 @@ fn test_multi_joins_with_duplicates() -> PolarsResult<()> { &df_right, &["col1", "join_col2"], &["join_col1", "col2"], - JoinType::Outer { coalesce: true }.into(), + JoinArgs::new(JoinType::Outer).with_coalesce(JoinCoalesce::CoalesceColumns), ) .unwrap(); @@ -586,7 +582,7 @@ fn test_join_floats() -> PolarsResult<()> { &df_b, vec!["a", "c"], vec!["foo", "bar"], - JoinType::Outer { coalesce: true }.into(), + JoinArgs::new(JoinType::Outer).with_coalesce(JoinCoalesce::CoalesceColumns), )?; assert_eq!( out.dtypes(), diff --git a/crates/polars/tests/it/joins.rs b/crates/polars/tests/it/joins.rs index 2e5435d1bd2c..80e9c31739b2 100644 --- a/crates/polars/tests/it/joins.rs +++ b/crates/polars/tests/it/joins.rs @@ -23,7 +23,8 @@ fn join_nans_outer() -> PolarsResult<()> { .with(a2) .left_on(vec![col("w"), col("t")]) .right_on(vec![col("w"), col("t")]) - .how(JoinType::Outer { coalesce: true }) + .how(JoinType::Outer) + .coalesce(JoinCoalesce::CoalesceColumns) .join_nulls(true) .finish() .collect()?; diff --git a/crates/polars/tests/it/lazy/projection_queries.rs b/crates/polars/tests/it/lazy/projection_queries.rs index 92035bef6a37..56a43e6efed4 100644 --- a/crates/polars/tests/it/lazy/projection_queries.rs +++ b/crates/polars/tests/it/lazy/projection_queries.rs @@ -54,7 +54,7 @@ fn test_outer_join_with_column_2988() -> PolarsResult<()> { ldf2, [col("key1"), col("key2")], [col("key1"), col("key2")], - JoinType::Outer { coalesce: true }.into(), + JoinArgs::new(JoinType::Outer).with_coalesce(JoinCoalesce::CoalesceColumns), ) .with_columns([col("key1")]) .collect()?; diff --git a/docs/src/rust/user-guide/transformations/joins.rs b/docs/src/rust/user-guide/transformations/joins.rs index 5c0526bba90a..cb557d31be18 100644 --- a/docs/src/rust/user-guide/transformations/joins.rs +++ b/docs/src/rust/user-guide/transformations/joins.rs @@ -58,7 +58,7 @@ fn main() -> Result<(), Box> { df_orders.clone().lazy(), [col("customer_id")], [col("customer_id")], - JoinArgs::new(JoinType::Outer { coalesce: false }), + JoinArgs::new(JoinType::Outer), ) .collect()?; println!("{}", &df_outer_join); @@ -72,7 +72,7 @@ fn main() -> Result<(), Box> { df_orders.clone().lazy(), [col("customer_id")], [col("customer_id")], - JoinArgs::new(JoinType::Outer { coalesce: true }), + JoinArgs::new(JoinType::Outer), ) .collect()?; println!("{}", &df_outer_join); diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index a9bac9e10730..229913c54f8d 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -3974,6 +3974,10 @@ def join( msg = "must specify `on` OR `left_on` and `right_on`" raise ValueError(msg) + coalesce = None + if how == "outer_coalesce": + coalesce = True + return self._from_pyldf( self._ldf.join( other._ldf, @@ -3985,6 +3989,7 @@ def join( how, suffix, validate, + coalesce, ) ) diff --git a/py-polars/src/conversion/mod.rs b/py-polars/src/conversion/mod.rs index 36351164a83f..cd4ea745bdc1 100644 --- a/py-polars/src/conversion/mod.rs +++ b/py-polars/src/conversion/mod.rs @@ -701,8 +701,11 @@ impl FromPyObject<'_> for Wrap { let parsed = match &*ob.extract::()? { "inner" => JoinType::Inner, "left" => JoinType::Left, - "outer" => JoinType::Outer{coalesce: false}, - "outer_coalesce" => JoinType::Outer{coalesce: true}, + "outer" => JoinType::Outer, + "outer_coalesce" => { + // TODO! deprecate + JoinType::Outer + }, "semi" => JoinType::Semi, "anti" => JoinType::Anti, #[cfg(feature = "cross_join")] diff --git a/py-polars/src/lazyframe/mod.rs b/py-polars/src/lazyframe/mod.rs index 253210cb18d9..c291a3411e1e 100644 --- a/py-polars/src/lazyframe/mod.rs +++ b/py-polars/src/lazyframe/mod.rs @@ -878,7 +878,13 @@ impl PyLazyFrame { how: Wrap, suffix: String, validate: Wrap, + coalesce: Option, ) -> PyResult { + let coalesce = match coalesce { + None => JoinCoalesce::JoinSpecific, + Some(true) => JoinCoalesce::CoalesceColumns, + Some(false) => JoinCoalesce::KeepColumns, + }; let ldf = self.ldf.clone(); let other = other.ldf; let left_on = left_on @@ -899,6 +905,7 @@ impl PyLazyFrame { .force_parallel(force_parallel) .join_nulls(join_nulls) .how(how.0) + .coalesce(coalesce) .validate(validate.0) .suffix(suffix) .finish()