diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index 6f540fa02c75..2883f4586c5b 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -50,14 +50,18 @@ use arrow::compute::{cast, concat};
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use datafusion_common::config::{CsvOptions, JsonOptions};
 use datafusion_common::{
-    exec_err, not_impl_err, plan_err, Column, DFSchema, DataFusionError, ParamValues,
-    SchemaError, UnnestOptions,
+    exec_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema,
+    DataFusionError, ParamValues, ScalarValue, SchemaError, UnnestOptions,
 };
-use datafusion_expr::dml::InsertOp;
-use datafusion_expr::{case, is_null, lit, SortExpr};
 use datafusion_expr::{
-    utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE,
+    case,
+    dml::InsertOp,
+    expr::{Alias, ScalarFunction},
+    is_null, lit,
+    utils::COUNT_STAR_EXPANSION,
+    SortExpr, TableProviderFilterPushDown, UNNAMED_TABLE,
 };
+use datafusion_functions::core::coalesce;
 use datafusion_functions_aggregate::expr_fn::{
     avg, count, max, median, min, stddev, sum,
 };
@@ -1926,6 +1930,89 @@ impl DataFrame {
             plan,
         })
     }
+
+    /// Fill null values in specified columns with a given value.
+    /// If no columns are specified (empty vector), applies to all columns.
+    /// A column is only filled if the value can be cast to the column's type.
+    ///
+    /// # Arguments
+    /// * `value` - Value to fill nulls with
+    /// * `columns` - List of column names to fill. If empty, fills all columns.
+    ///
+    /// # Example
+    /// ```
+    /// # use datafusion::prelude::*;
+    /// # use datafusion::error::Result;
+    /// # use datafusion_common::ScalarValue;
+    /// # #[tokio::main]
+    /// # async fn main() -> Result<()> {
+    /// let ctx = SessionContext::new();
+    /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?;
+    /// // Fill nulls in only columns "a" and "c":
+    /// let df = df.fill_null(ScalarValue::from(0), vec!["a".to_owned(), "c".to_owned()])?;
+    /// // Fill nulls across all columns:
+    /// let df = df.fill_null(ScalarValue::from(0), vec![])?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn fill_null(
+        &self,
+        value: ScalarValue,
+        columns: Vec<String>,
+    ) -> Result<DataFrame> {
+        let cols = if columns.is_empty() {
+            self.logical_plan()
+                .schema()
+                .fields()
+                .iter()
+                .map(|f| f.as_ref().clone())
+                .collect()
+        } else {
+            self.find_columns(&columns)?
+        };
+
+        // Create projections for each column
+        let projections = self
+            .logical_plan()
+            .schema()
+            .fields()
+            .iter()
+            .map(|field| {
+                if cols.contains(field) {
+                    // Try to cast fill value to column type. If the cast fails, fall back to the original column.
+                    match value.clone().cast_to(field.data_type()) {
+                        Ok(fill_value) => Expr::Alias(Alias {
+                            expr: Box::new(Expr::ScalarFunction(ScalarFunction {
+                                func: coalesce(),
+                                args: vec![col(field.name()), lit(fill_value)],
+                            })),
+                            relation: None,
+                            name: field.name().to_string(),
+                        }),
+                        Err(_) => col(field.name()),
+                    }
+                } else {
+                    col(field.name())
+                }
+            })
+            .collect::<Vec<_>>();
+
+        self.clone().select(projections)
+    }
+
+    // Helper to find columns from names
+    fn find_columns(&self, names: &[String]) -> Result<Vec<Field>> {
+        let schema = self.logical_plan().schema();
+        names
+            .iter()
+            .map(|name| {
+                schema
+                    .field_with_name(None, name)
+                    .cloned()
+                    .map_err(|_| plan_datafusion_err!("Column '{}' not found", name))
+            })
+            .collect()
+    }
 }
 
 #[derive(Debug)]
diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index d545157607c7..ae7e46a1ef89 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -5342,3 +5342,97 @@ async fn test_insert_into_checking() -> Result<()> {
 
     Ok(())
 }
+
+async fn create_null_table() -> Result<DataFrame> {
+    // create a DataFrame with null values
+    // "+---+---+",
+    // "| a | b |",
+    // "+---+---+",
+    // "| 1 | x |",
+    // "|   |   |",
+    // "| 3 | z |",
+    // "+---+---+",
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::Int32, true),
+        Field::new("b", DataType::Utf8, true),
+    ]));
+    let a_values = Int32Array::from(vec![Some(1), None, Some(3)]);
+    let b_values = StringArray::from(vec![Some("x"), None, Some("z")]);
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![Arc::new(a_values), Arc::new(b_values)],
+    )?;
+
+    let ctx = SessionContext::new();
+    let table = MemTable::try_new(schema.clone(), vec![vec![batch]])?;
+    ctx.register_table("t_null", Arc::new(table))?;
+    let df = ctx.table("t_null").await?;
+    Ok(df)
+}
+
+#[tokio::test]
+async fn test_fill_null() -> Result<()> {
+    let df = create_null_table().await?;
+
+    // Use fill_null to replace nulls on each column.
+    let df_filled = df
+        .fill_null(ScalarValue::Int32(Some(0)), vec!["a".to_string()])?
+        .fill_null(
+            ScalarValue::Utf8(Some("default".to_string())),
+            vec!["b".to_string()],
+        )?;
+
+    let results = df_filled.collect().await?;
+    let expected = [
+        "+---+---------+",
+        "| a | b       |",
+        "+---+---------+",
+        "| 1 | x       |",
+        "| 0 | default |",
+        "| 3 | z       |",
+        "+---+---------+",
+    ];
+    assert_batches_sorted_eq!(expected, &results);
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_fill_null_all_columns() -> Result<()> {
+    let df = create_null_table().await?;
+
+    // Use fill_null to replace nulls on all columns.
+    // Only column "b" will be replaced since ScalarValue::Utf8(Some("default".to_string()))
+    // cannot be cast to Int32, only to Utf8.
+    let df_filled =
+        df.fill_null(ScalarValue::Utf8(Some("default".to_string())), vec![])?;
+
+    let results = df_filled.clone().collect().await?;
+
+    let expected = [
+        "+---+---------+",
+        "| a | b       |",
+        "+---+---------+",
+        "| 1 | x       |",
+        "|   | default |",
+        "| 3 | z       |",
+        "+---+---------+",
+    ];
+
+    assert_batches_sorted_eq!(expected, &results);
+
+    // Fill column "a" null values with a value that can be cast to Int32.
+    let df_filled = df_filled.fill_null(ScalarValue::Int32(Some(0)), vec![])?;
+
+    let results = df_filled.collect().await?;
+    let expected = [
+        "+---+---------+",
+        "| a | b       |",
+        "+---+---------+",
+        "| 1 | x       |",
+        "| 0 | default |",
+        "| 3 | z       |",
+        "+---+---------+",
+    ];
+    assert_batches_sorted_eq!(expected, &results);
+    Ok(())
+}
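For context, a minimal standalone sketch of how the new API could be exercised from a consumer crate. It is not part of the patch: the table setup mirrors `create_null_table` above, `SessionContext::read_batch` and `DataFrame::show` are existing DataFusion APIs, and the crate paths (via the `datafusion` re-exports) are assumptions that may need adjusting for a given project layout.

```rust
// Usage sketch: assumes this patch is applied so `DataFrame::fill_null` exists.
use std::sync::Arc;

use datafusion::arrow::array::{Int32Array, StringArray};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::ScalarValue;
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    // Build a small in-memory table with nulls in both columns,
    // mirroring `create_null_table` from the tests above.
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Utf8, true),
    ]));
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])),
            Arc::new(StringArray::from(vec![Some("x"), None, Some("z")])),
        ],
    )?;

    let ctx = SessionContext::new();
    let df = ctx.read_batch(batch)?;

    // Fill nulls in "a" only; an empty column list would target every column,
    // and a column whose type the value cannot be cast to is left unchanged.
    let filled = df.fill_null(ScalarValue::Int32(Some(0)), vec!["a".to_string()])?;
    filled.show().await?;
    Ok(())
}
```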