Skip to content

Commit 161d0f2

Browse files
authored
fix dml logical plan output schema (#10394)
* fix dml logical plan output schema Previously, `LogicalPlan::schema` would return the input schema for Dml plans, rather than the expected output schema. This is an unusual case since DMLs are typically not run for their output, but it is typical for the output to be the `count` of rows affected by the DML statement. See `fn dml_output_schema` for a test. * document DmlStatement::new * Fix expected logical schema of 'insert into' in sqllogictests
1 parent 0681004 commit 161d0f2

File tree

12 files changed

+116
-72
lines changed

12 files changed

+116
-72
lines changed

datafusion/core/tests/sql/sql_api.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,19 @@ async fn unsupported_dml_returns_error() {
5858
ctx.sql_with_options(sql, options).await.unwrap();
5959
}
6060

61+
#[tokio::test]
62+
async fn dml_output_schema() {
63+
use arrow::datatypes::Schema;
64+
use arrow::datatypes::{DataType, Field};
65+
66+
let ctx = SessionContext::new();
67+
ctx.sql("CREATE TABLE test (x int)").await.unwrap();
68+
let sql = "INSERT INTO test VALUES (1)";
69+
let df = ctx.sql(sql).await.unwrap();
70+
let count_schema = Schema::new(vec![Field::new("count", DataType::UInt64, false)]);
71+
assert_eq!(Schema::from(df.schema()), count_schema);
72+
}
73+
6174
#[tokio::test]
6275
async fn unsupported_copy_returns_error() {
6376
let tmpdir = TempDir::new().unwrap();

datafusion/expr/src/logical_plan/builder.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -296,12 +296,12 @@ impl LogicalPlanBuilder {
296296
WriteOp::InsertInto
297297
};
298298

299-
Ok(Self::from(LogicalPlan::Dml(DmlStatement {
300-
table_name: table_name.into(),
299+
Ok(Self::from(LogicalPlan::Dml(DmlStatement::new(
300+
table_name.into(),
301301
table_schema,
302302
op,
303-
input: Arc::new(input),
304-
})))
303+
Arc::new(input),
304+
))))
305305
}
306306

307307
/// Convert a table provider into a builder with a TableScan

datafusion/expr/src/logical_plan/dml.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use std::fmt::{self, Display};
2020
use std::hash::{Hash, Hasher};
2121
use std::sync::Arc;
2222

23+
use arrow::datatypes::{DataType, Field, Schema};
2324
use datafusion_common::config::FormatOptions;
2425
use datafusion_common::{DFSchemaRef, TableReference};
2526

@@ -70,9 +71,29 @@ pub struct DmlStatement {
7071
pub op: WriteOp,
7172
/// The relation that determines the tuples to add/remove/modify the schema must match with table_schema
7273
pub input: Arc<LogicalPlan>,
74+
/// The schema of the output relation
75+
pub output_schema: DFSchemaRef,
7376
}
7477

7578
impl DmlStatement {
79+
/// Creates a new DML statement with the output schema set to a single `count` column.
80+
pub fn new(
81+
table_name: TableReference,
82+
table_schema: DFSchemaRef,
83+
op: WriteOp,
84+
input: Arc<LogicalPlan>,
85+
) -> Self {
86+
Self {
87+
table_name,
88+
table_schema,
89+
op,
90+
input,
91+
92+
// The output schema is always a single column with the number of rows affected
93+
output_schema: make_count_schema(),
94+
}
95+
}
96+
7697
/// Return a descriptive name of this [`DmlStatement`]
7798
pub fn name(&self) -> &str {
7899
self.op.name()
@@ -106,3 +127,11 @@ impl Display for WriteOp {
106127
write!(f, "{}", self.name())
107128
}
108129
}
130+
131+
fn make_count_schema() -> DFSchemaRef {
132+
Arc::new(
133+
Schema::new(vec![Field::new("count", DataType::UInt64, false)])
134+
.try_into()
135+
.unwrap(),
136+
)
137+
}

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ impl LogicalPlan {
191191
LogicalPlan::DescribeTable(DescribeTable { output_schema, .. }) => {
192192
output_schema
193193
}
194-
LogicalPlan::Dml(DmlStatement { table_schema, .. }) => table_schema,
194+
LogicalPlan::Dml(DmlStatement { output_schema, .. }) => output_schema,
195195
LogicalPlan::Copy(CopyTo { input, .. }) => input.schema(),
196196
LogicalPlan::Ddl(ddl) => ddl.schema(),
197197
LogicalPlan::Unnest(Unnest { schema, .. }) => schema,
@@ -509,12 +509,12 @@ impl LogicalPlan {
509509
table_schema,
510510
op,
511511
..
512-
}) => Ok(LogicalPlan::Dml(DmlStatement {
513-
table_name: table_name.clone(),
514-
table_schema: table_schema.clone(),
515-
op: op.clone(),
516-
input: Arc::new(inputs.swap_remove(0)),
517-
})),
512+
}) => Ok(LogicalPlan::Dml(DmlStatement::new(
513+
table_name.clone(),
514+
table_schema.clone(),
515+
op.clone(),
516+
Arc::new(inputs.swap_remove(0)),
517+
))),
518518
LogicalPlan::Copy(CopyTo {
519519
input: _,
520520
output_url,

datafusion/expr/src/logical_plan/tree_node.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,12 +242,14 @@ impl TreeNode for LogicalPlan {
242242
table_schema,
243243
op,
244244
input,
245+
output_schema,
245246
}) => rewrite_arc(input, f)?.update_data(|input| {
246247
LogicalPlan::Dml(DmlStatement {
247248
table_name,
248249
table_schema,
249250
op,
250251
input,
252+
output_schema,
251253
})
252254
}),
253255
LogicalPlan::Copy(CopyTo {

datafusion/sql/src/statement.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,12 +1206,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
12061206
}
12071207
};
12081208

1209-
let plan = LogicalPlan::Dml(DmlStatement {
1210-
table_name: table_ref,
1211-
table_schema: schema.into(),
1212-
op: WriteOp::Delete,
1213-
input: Arc::new(source),
1214-
});
1209+
let plan = LogicalPlan::Dml(DmlStatement::new(
1210+
table_ref,
1211+
schema.into(),
1212+
WriteOp::Delete,
1213+
Arc::new(source),
1214+
));
12151215
Ok(plan)
12161216
}
12171217

@@ -1318,12 +1318,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
13181318

13191319
let source = project(source, exprs)?;
13201320

1321-
let plan = LogicalPlan::Dml(DmlStatement {
1321+
let plan = LogicalPlan::Dml(DmlStatement::new(
13221322
table_name,
13231323
table_schema,
1324-
op: WriteOp::Update,
1325-
input: Arc::new(source),
1326-
});
1324+
WriteOp::Update,
1325+
Arc::new(source),
1326+
));
13271327
Ok(plan)
13281328
}
13291329

@@ -1441,12 +1441,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
14411441
WriteOp::InsertInto
14421442
};
14431443

1444-
let plan = LogicalPlan::Dml(DmlStatement {
1444+
let plan = LogicalPlan::Dml(DmlStatement::new(
14451445
table_name,
1446-
table_schema: Arc::new(table_schema),
1446+
Arc::new(table_schema),
14471447
op,
1448-
input: Arc::new(source),
1449-
});
1448+
Arc::new(source),
1449+
));
14501450
Ok(plan)
14511451
}
14521452

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3477,7 +3477,7 @@ SELECT STRING_AGG(column1, '|') FROM (values (''), (null), (''));
34773477
statement ok
34783478
CREATE TABLE strings(g INTEGER, x VARCHAR, y VARCHAR)
34793479

3480-
query ITT
3480+
query I
34813481
INSERT INTO strings VALUES (1,'a','/'), (1,'b','-'), (2,'i','/'), (2,NULL,'-'), (2,'j','+'), (3,'p','/'), (4,'x','/'), (4,'y','-'), (4,'z','+')
34823482
----
34833483
9

datafusion/sqllogictest/test_files/array.slt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6475,7 +6475,7 @@ create table test_create_array_table(
64756475
d int
64766476
);
64776477

6478-
query ???I
6478+
query I
64796479
insert into test_create_array_table values
64806480
([1, 2, 3], ['a', 'b', 'c'], [[4,6], [6,7,8]], 1);
64816481
----

datafusion/sqllogictest/test_files/create_external_table.slt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ PARTITIONED BY (p1 string, p2 string)
130130
STORED AS parquet
131131
LOCATION 'test_files/scratch/create_external_table/bad_partitioning/';
132132

133-
query ITT
133+
query I
134134
INSERT INTO partitioned VALUES (1, 'x', 'y');
135135
----
136136
1
@@ -186,13 +186,13 @@ PARTITIONED BY (month string, year string)
186186
STORED AS parquet
187187
LOCATION 'test_files/scratch/create_external_table/manual_partitioning/';
188188

189-
query TTT
189+
query I
190190
-- creates year -> month partitions
191191
INSERT INTO test VALUES('name', '2024', '03');
192192
----
193193
1
194194

195-
query TTT
195+
query I
196196
-- creates month -> year partitions.
197197
-- now table have both partitions (year -> month and month -> year)
198198
INSERT INTO test2 VALUES('name', '2024', '03');

datafusion/sqllogictest/test_files/insert.slt

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ physical_plan
7575
09)----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1
7676
10)------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c4, c9], has_header=true
7777

78-
query II
78+
query I
7979
INSERT INTO table_without_values SELECT
8080
SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),
8181
COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)
@@ -137,7 +137,7 @@ physical_plan
137137

138138

139139

140-
query II
140+
query I
141141
INSERT INTO table_without_values SELECT
142142
SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1,
143143
COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2
@@ -187,7 +187,7 @@ physical_plan
187187
10)------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c4, c9], has_header=true
188188

189189

190-
query II
190+
query I
191191
INSERT INTO table_without_values SELECT
192192
SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1,
193193
COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2
@@ -221,7 +221,7 @@ physical_plan
221221
02)--SortExec: expr=[c1@0 ASC NULLS LAST], preserve_partitioning=[false]
222222
03)----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true
223223

224-
query T
224+
query I
225225
insert into table_without_values select c1 from aggregate_test_100 order by c1;
226226
----
227227
100
@@ -239,12 +239,12 @@ drop table table_without_values;
239239
statement ok
240240
CREATE TABLE table_without_values(id BIGINT, name varchar);
241241

242-
query IT
242+
query I
243243
insert into table_without_values(id, name) values(1, 'foo');
244244
----
245245
1
246246

247-
query IT
247+
query I
248248
insert into table_without_values(name, id) values('bar', 2);
249249
----
250250
1
@@ -259,7 +259,7 @@ statement error Error during planning: Column count doesn't match insert query!
259259
insert into table_without_values(id) values(4, 'zoo');
260260

261261
# insert NULL values for the missing column (name)
262-
query IT
262+
query I
263263
insert into table_without_values(id) values(4);
264264
----
265265
1
@@ -279,18 +279,18 @@ drop table table_without_values;
279279
statement ok
280280
CREATE TABLE table_without_values(field1 BIGINT NOT NULL, field2 BIGINT NULL);
281281

282-
query II
282+
query I
283283
insert into table_without_values values(1, 100);
284284
----
285285
1
286286

287-
query II
287+
query I
288288
insert into table_without_values values(2, NULL);
289289
----
290290
1
291291

292292
# insert NULL values for the missing column (field2)
293-
query II
293+
query I
294294
insert into table_without_values(field1) values(3);
295295
----
296296
1
@@ -363,15 +363,15 @@ create table test_column_defaults(
363363
e timestamp default now()
364364
)
365365

366-
query IIITP
366+
query I
367367
insert into test_column_defaults values(1, 10, 100, 'ABC', now())
368368
----
369369
1
370370

371371
statement error DataFusion error: Execution error: Invalid batch column at '1' has null but schema specifies non-nullable
372372
insert into test_column_defaults(a) values(2)
373373

374-
query IIITP
374+
query I
375375
insert into test_column_defaults(b) values(20)
376376
----
377377
1
@@ -383,7 +383,7 @@ select a,b,c,d from test_column_defaults
383383
NULL 20 500 default_text
384384

385385
# fill the timestamp column with default value `now()` again, it should be different from the previous one
386-
query IIITP
386+
query I
387387
insert into test_column_defaults(a, b, c, d) values(2, 20, 200, 'DEF')
388388
----
389389
1
@@ -417,7 +417,7 @@ create table test_column_defaults(
417417
e timestamp default now()
418418
) as values(1, 10, 100, 'ABC', now())
419419

420-
query IIITP
420+
query I
421421
insert into test_column_defaults(b) values(20)
422422
----
423423
1

0 commit comments

Comments
 (0)