
Commit f51cd6e

Add DataFrame fill_null (#14769)
* feat: add fill_null methods to DataFrame for handling null values
* test: refactor fill_null tests and create helper function for null table
* style: reorder imports in mod.rs for better organization
* clippy lint
* test: add comment to clarify test
* refactor: columns Vec<String>
* docs: enhance fill_null documentation with example usage
* test: columns Vec<String>
* docs: update fill_null documentation with detailed usage examples
1 parent dd7fe8f commit f51cd6e
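For a quick view of the new API, here is the usage example from the rustdoc added in this commit (reproduced from the diff of datafusion/core/src/dataframe/mod.rs below); the CSV path is the one the doc example uses and is not assumed to exist outside the DataFusion repository:

```
use datafusion::prelude::*;
use datafusion::error::Result;
use datafusion_common::ScalarValue;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?;
    // Fill nulls in only columns "a" and "c":
    let df = df.fill_null(ScalarValue::from(0), vec!["a".to_owned(), "c".to_owned()])?;
    // Fill nulls across all columns:
    let _df = df.fill_null(ScalarValue::from(0), vec![])?;
    Ok(())
}
```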

2 files changed (+186, -5)

datafusion/core/src/dataframe/mod.rs (+92, -5)

@@ -51,14 +51,18 @@ use arrow::compute::{cast, concat};
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use datafusion_common::config::{CsvOptions, JsonOptions};
 use datafusion_common::{
-    exec_err, not_impl_err, plan_err, Column, DFSchema, DataFusionError, ParamValues,
-    SchemaError, UnnestOptions,
+    exec_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema,
+    DataFusionError, ParamValues, ScalarValue, SchemaError, UnnestOptions,
 };
-use datafusion_expr::dml::InsertOp;
-use datafusion_expr::{case, is_null, lit, SortExpr};
 use datafusion_expr::{
-    utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE,
+    case,
+    dml::InsertOp,
+    expr::{Alias, ScalarFunction},
+    is_null, lit,
+    utils::COUNT_STAR_EXPANSION,
+    SortExpr, TableProviderFilterPushDown, UNNAMED_TABLE,
 };
+use datafusion_functions::core::coalesce;
 use datafusion_functions_aggregate::expr_fn::{
     avg, count, max, median, min, stddev, sum,
 };
@@ -1930,6 +1934,89 @@ impl DataFrame {
             plan,
         })
     }
+
+    /// Fill null values in specified columns with a given value
+    /// If no columns are specified (empty vector), applies to all columns
+    /// Only fills if the value can be cast to the column's type
+    ///
+    /// # Arguments
+    /// * `value` - Value to fill nulls with
+    /// * `columns` - List of column names to fill. If empty, fills all columns.
+    ///
+    /// # Example
+    /// ```
+    /// # use datafusion::prelude::*;
+    /// # use datafusion::error::Result;
+    /// # use datafusion_common::ScalarValue;
+    /// # #[tokio::main]
+    /// # async fn main() -> Result<()> {
+    /// let ctx = SessionContext::new();
+    /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?;
+    /// // Fill nulls in only columns "a" and "c":
+    /// let df = df.fill_null(ScalarValue::from(0), vec!["a".to_owned(), "c".to_owned()])?;
+    /// // Fill nulls across all columns:
+    /// let df = df.fill_null(ScalarValue::from(0), vec![])?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn fill_null(
+        &self,
+        value: ScalarValue,
+        columns: Vec<String>,
+    ) -> Result<DataFrame> {
+        let cols = if columns.is_empty() {
+            self.logical_plan()
+                .schema()
+                .fields()
+                .iter()
+                .map(|f| f.as_ref().clone())
+                .collect()
+        } else {
+            self.find_columns(&columns)?
+        };
+
+        // Create projections for each column
+        let projections = self
+            .logical_plan()
+            .schema()
+            .fields()
+            .iter()
+            .map(|field| {
+                if cols.contains(field) {
+                    // Try to cast fill value to column type. If the cast fails, fallback to the original column.
+                    match value.clone().cast_to(field.data_type()) {
+                        Ok(fill_value) => Expr::Alias(Alias {
+                            expr: Box::new(Expr::ScalarFunction(ScalarFunction {
+                                func: coalesce(),
+                                args: vec![col(field.name()), lit(fill_value)],
+                            })),
+                            relation: None,
+                            name: field.name().to_string(),
+                        }),
+                        Err(_) => col(field.name()),
+                    }
+                } else {
+                    col(field.name())
+                }
+            })
+            .collect::<Vec<_>>();
+
+        self.clone().select(projections)
+    }
+
+    // Helper to find columns from names
+    fn find_columns(&self, names: &[String]) -> Result<Vec<Field>> {
+        let schema = self.logical_plan().schema();
+        names
+            .iter()
+            .map(|name| {
+                schema
+                    .field_with_name(None, name)
+                    .cloned()
+                    .map_err(|_| plan_datafusion_err!("Column '{}' not found", name))
+            })
+            .collect()
+    }
 }
 
 #[derive(Debug)]
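Under the hood, fill_null is a projection rewrite: every targeted column is wrapped in COALESCE(column, fill_value) and aliased back to its original name, while columns whose type the value cannot be cast to pass through unchanged. A hand-built equivalent for filling column "a" with 0 on a two-column table might look like the sketch below; it assumes the coalesce expression helper exported from datafusion_functions::core::expr_fn, which is not what the commit itself calls (the commit constructs the ScalarFunction expression directly, as in the diff above):

```
use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_functions::core::expr_fn::coalesce;

// Sketch: the projection that df.fill_null(ScalarValue::from(0), vec!["a".into()])
// conceptually builds: COALESCE on the targeted column, pass-through for the rest.
fn fill_a_with_zero(df: DataFrame) -> Result<DataFrame> {
    df.select(vec![
        coalesce(vec![col("a"), lit(0)]).alias("a"),
        col("b"),
    ])
}
```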

datafusion/core/tests/dataframe/mod.rs (+94)

@@ -5342,3 +5342,97 @@ async fn test_insert_into_checking() -> Result<()> {
 
     Ok(())
 }
+
+async fn create_null_table() -> Result<DataFrame> {
+    // create a DataFrame with null values
+    // "+---+---+",
+    // "| a | b |",
+    // "+---+---+",
+    // "| 1 | x |",
+    // "|   |   |",
+    // "| 3 | z |",
+    // "+---+---+",
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::Int32, true),
+        Field::new("b", DataType::Utf8, true),
+    ]));
+    let a_values = Int32Array::from(vec![Some(1), None, Some(3)]);
+    let b_values = StringArray::from(vec![Some("x"), None, Some("z")]);
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![Arc::new(a_values), Arc::new(b_values)],
+    )?;
+
+    let ctx = SessionContext::new();
+    let table = MemTable::try_new(schema.clone(), vec![vec![batch]])?;
+    ctx.register_table("t_null", Arc::new(table))?;
+    let df = ctx.table("t_null").await?;
+    Ok(df)
+}
+
+#[tokio::test]
+async fn test_fill_null() -> Result<()> {
+    let df = create_null_table().await?;
+
+    // Use fill_null to replace nulls on each column.
+    let df_filled = df
+        .fill_null(ScalarValue::Int32(Some(0)), vec!["a".to_string()])?
+        .fill_null(
+            ScalarValue::Utf8(Some("default".to_string())),
+            vec!["b".to_string()],
+        )?;
+
+    let results = df_filled.collect().await?;
+    let expected = [
+        "+---+---------+",
+        "| a | b       |",
+        "+---+---------+",
+        "| 1 | x       |",
+        "| 0 | default |",
+        "| 3 | z       |",
+        "+---+---------+",
+    ];
+    assert_batches_sorted_eq!(expected, &results);
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_fill_null_all_columns() -> Result<()> {
+    let df = create_null_table().await?;
+
+    // Use fill_null to replace nulls on all columns.
+    // Only column "b" will be replaced since ScalarValue::Utf8(Some("default".to_string()))
+    // can be cast to Utf8.
+    let df_filled =
+        df.fill_null(ScalarValue::Utf8(Some("default".to_string())), vec![])?;
+
+    let results = df_filled.clone().collect().await?;
+
+    let expected = [
+        "+---+---------+",
+        "| a | b       |",
+        "+---+---------+",
+        "| 1 | x       |",
+        "|   | default |",
+        "| 3 | z       |",
+        "+---+---------+",
+    ];
+
+    assert_batches_sorted_eq!(expected, &results);
+
+    // Fill column "a" null values with a value that can be cast to Int32.
+    let df_filled = df_filled.fill_null(ScalarValue::Int32(Some(0)), vec![])?;
+
+    let results = df_filled.collect().await?;
+    let expected = [
+        "+---+---------+",
+        "| a | b       |",
+        "+---+---------+",
+        "| 1 | x       |",
+        "| 0 | default |",
+        "| 3 | z       |",
+        "+---+---------+",
+    ];
+    assert_batches_sorted_eq!(expected, &results);
+    Ok(())
+}
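One behavior worth calling out from the diff: column names are validated by the new find_columns helper, so asking for a column that does not exist returns a plan error rather than being silently ignored. A minimal sketch of that error path (the column name here is made up for illustration):

```
use datafusion::prelude::DataFrame;
use datafusion_common::ScalarValue;

// Sketch: fill_null on a nonexistent column should surface the
// "Column '...' not found" plan error produced by find_columns.
fn fill_missing_column_errors(df: &DataFrame) -> bool {
    df.fill_null(ScalarValue::Int32(Some(0)), vec!["no_such_column".to_string()])
        .is_err()
}
```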
