Skip to content

Commit 83c8efe

Browse files
andygrove authored and michalursa committed
ARROW-12335: [Rust] [Ballista] Use latest DataFusion
Updates Ballista to use the most recent DataFusion version.

Changes made:

- Ballista overrides physical optimizer rules to remove `Repartition`
- Added serde support for new `TryCast` expression
- Updated DataFrame API usage to use `Vec<_>` instead of `&[_]`
- Renamed some timestamp scalar variants
- HashJoinExec updated to take new `CollectLeft` argument
- Removed hard-coded batch size from serde code for `CsvScanExec`

Closes apache#9991 from andygrove/ballista-bump-df-version

Authored-by: Andy Grove <[email protected]>
Signed-off-by: Krisztián Szűcs <[email protected]>
1 parent 73f92ce commit 83c8efe

File tree

21 files changed

+189
-92
lines changed

21 files changed

+189
-92
lines changed

rust/ballista/.dockerignore

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
rust/**/target

rust/ballista/rust/Cargo.toml

+3-3
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,6 @@ members = [
2525
"scheduler",
2626
]
2727

28-
[profile.release]
29-
lto = true
30-
codegen-units = 1
28+
#[profile.release]
29+
#lto = true
30+
#codegen-units = 1

rust/ballista/rust/benchmarks/tpch/Cargo.toml

+7-3
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,13 @@ edition = "2018"
2727
[dependencies]
2828
ballista = { path="../../client" }
2929

30-
arrow = { git = "https://github.com/apache/arrow", rev="46161d2" }
31-
datafusion = { git = "https://github.com/apache/arrow", rev="46161d2" }
32-
parquet = { git = "https://github.com/apache/arrow", rev="46161d2" }
30+
#arrow = { path = "../../../../arrow" }
31+
#datafusion = { path = "../../../../datafusion" }
32+
#parquet = { path = "../../../../parquet" }
33+
34+
arrow = { git = "https://github.com/apache/arrow", rev="fe83dca" }
35+
datafusion = { git = "https://github.com/apache/arrow", rev="fe83dca" }
36+
parquet = { git = "https://github.com/apache/arrow", rev="fe83dca" }
3337

3438

3539
env_logger = "0.8"

rust/ballista/rust/client/Cargo.toml

+6-2
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,9 @@ ballista-core = { path = "../core" }
3030
futures = "0.3"
3131
log = "0.4"
3232
tokio = "1.0"
33-
arrow = { git = "https://github.com/apache/arrow", rev="46161d2" }
34-
datafusion = { git = "https://github.com/apache/arrow", rev="46161d2" }
33+
34+
#arrow = { path = "../../../arrow" }
35+
#datafusion = { path = "../../../datafusion" }
36+
37+
arrow = { git = "https://github.com/apache/arrow", rev="fe83dca" }
38+
datafusion = { git = "https://github.com/apache/arrow", rev="fe83dca" }

rust/ballista/rust/client/src/context.rs

+9-5
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ use ballista_core::{
3636
};
3737

3838
use arrow::datatypes::Schema;
39+
use datafusion::catalog::TableReference;
3940
use datafusion::execution::context::ExecutionContext;
4041
use datafusion::logical_plan::{DFSchema, Expr, LogicalPlan, Partitioning};
4142
use datafusion::physical_plan::csv::CsvReadOptions;
@@ -148,7 +149,10 @@ impl BallistaContext {
148149
for (name, plan) in &state.tables {
149150
let plan = ctx.optimize(plan)?;
150151
let execution_plan = ctx.create_physical_plan(&plan)?;
151-
ctx.register_table(name, Arc::new(DFTableAdapter::new(plan, execution_plan)));
152+
ctx.register_table(
153+
TableReference::Bare { table: name },
154+
Arc::new(DFTableAdapter::new(plan, execution_plan)),
155+
)?;
152156
}
153157
let df = ctx.sql(sql)?;
154158
Ok(BallistaDataFrame::from(self.state.clone(), df))
@@ -267,7 +271,7 @@ impl BallistaDataFrame {
267271
))
268272
}
269273

270-
pub fn select(&self, expr: &[Expr]) -> Result<BallistaDataFrame> {
274+
pub fn select(&self, expr: Vec<Expr>) -> Result<BallistaDataFrame> {
271275
Ok(Self::from(
272276
self.state.clone(),
273277
self.df.select(expr).map_err(BallistaError::from)?,
@@ -283,8 +287,8 @@ impl BallistaDataFrame {
283287

284288
pub fn aggregate(
285289
&self,
286-
group_expr: &[Expr],
287-
aggr_expr: &[Expr],
290+
group_expr: Vec<Expr>,
291+
aggr_expr: Vec<Expr>,
288292
) -> Result<BallistaDataFrame> {
289293
Ok(Self::from(
290294
self.state.clone(),
@@ -301,7 +305,7 @@ impl BallistaDataFrame {
301305
))
302306
}
303307

304-
pub fn sort(&self, expr: &[Expr]) -> Result<BallistaDataFrame> {
308+
pub fn sort(&self, expr: Vec<Expr>) -> Result<BallistaDataFrame> {
305309
Ok(Self::from(
306310
self.state.clone(),
307311
self.df.sort(expr).map_err(BallistaError::from)?,

rust/ballista/rust/core/Cargo.toml

+7-3
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,14 @@ sqlparser = "0.8"
3939
tokio = "1.0"
4040
tonic = "0.4"
4141
uuid = { version = "0.8", features = ["v4"] }
42-
arrow = { git = "https://github.com/apache/arrow", rev="46161d2" }
43-
arrow-flight = { git = "https://github.com/apache/arrow", rev="46161d2" }
44-
datafusion = { git = "https://github.com/apache/arrow", rev="46161d2" }
4542

43+
#arrow = { path = "../../../arrow" }
44+
#arrow-flight = { path = "../../../arrow-flight" }
45+
#datafusion = { path = "../../../datafusion" }
46+
47+
arrow = { git = "https://github.com/apache/arrow", rev="fe83dca" }
48+
arrow-flight = { git = "https://github.com/apache/arrow", rev="fe83dca" }
49+
datafusion = { git = "https://github.com/apache/arrow", rev="fe83dca" }
4650

4751
[dev-dependencies]
4852

rust/ballista/rust/core/proto/ballista.proto

+6
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ message LogicalExprNode {
5959
InListNode in_list = 14;
6060
bool wildcard = 15;
6161
ScalarFunctionNode scalar_function = 16;
62+
TryCastNode try_cast = 17;
6263
}
6364
}
6465

@@ -172,6 +173,11 @@ message CastNode {
172173
ArrowType arrow_type = 2;
173174
}
174175

176+
message TryCastNode {
177+
LogicalExprNode expr = 1;
178+
ArrowType arrow_type = 2;
179+
}
180+
175181
message SortExprNode {
176182
LogicalExprNode expr = 1;
177183
bool asc = 2;

rust/ballista/rust/core/src/datasource.rs

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ impl TableProvider for DFTableAdapter {
5757
_projection: &Option<Vec<usize>>,
5858
_batch_size: usize,
5959
_filters: &[Expr],
60+
_limit: Option<usize>,
6061
) -> DFResult<Arc<dyn ExecutionPlan>> {
6162
Ok(self.plan.clone())
6263
}

rust/ballista/rust/core/src/serde/logical_plan/from_proto.rs

+27-19
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,13 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
5252
match plan {
5353
LogicalPlanType::Projection(projection) => {
5454
let input: LogicalPlan = convert_box_required!(projection.input)?;
55+
let x: Vec<Expr> = projection
56+
.expr
57+
.iter()
58+
.map(|expr| expr.try_into())
59+
.collect::<Result<Vec<_>, _>>()?;
5560
LogicalPlanBuilder::from(&input)
56-
.project(
57-
&projection
58-
.expr
59-
.iter()
60-
.map(|expr| expr.try_into())
61-
.collect::<Result<Vec<_>, _>>()?,
62-
)?
61+
.project(x)?
6362
.build()
6463
.map_err(|e| e.into())
6564
}
@@ -89,7 +88,7 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
8988
.map(|expr| expr.try_into())
9089
.collect::<Result<Vec<_>, _>>()?;
9190
LogicalPlanBuilder::from(&input)
92-
.aggregate(&group_expr, &aggr_expr)?
91+
.aggregate(group_expr, aggr_expr)?
9392
.build()
9493
.map_err(|e| e.into())
9594
}
@@ -148,7 +147,7 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
148147
.map(|expr| expr.try_into())
149148
.collect::<Result<Vec<Expr>, _>>()?;
150149
LogicalPlanBuilder::from(&input)
151-
.sort(&sort_expr)?
150+
.sort(sort_expr)?
152151
.build()
153152
.map_err(|e| e.into())
154153
}
@@ -511,10 +510,10 @@ fn typechecked_scalar_value_conversion(
511510
ScalarValue::Date32(Some(*v))
512511
}
513512
(Value::TimeMicrosecondValue(v), PrimitiveScalarType::TimeMicrosecond) => {
514-
ScalarValue::TimeMicrosecond(Some(*v))
513+
ScalarValue::TimestampMicrosecond(Some(*v))
515514
}
516515
(Value::TimeNanosecondValue(v), PrimitiveScalarType::TimeMicrosecond) => {
517-
ScalarValue::TimeNanosecond(Some(*v))
516+
ScalarValue::TimestampNanosecond(Some(*v))
518517
}
519518
(Value::Utf8Value(v), PrimitiveScalarType::Utf8) => {
520519
ScalarValue::Utf8(Some(v.to_owned()))
@@ -547,10 +546,10 @@ fn typechecked_scalar_value_conversion(
547546
PrimitiveScalarType::LargeUtf8 => ScalarValue::LargeUtf8(None),
548547
PrimitiveScalarType::Date32 => ScalarValue::Date32(None),
549548
PrimitiveScalarType::TimeMicrosecond => {
550-
ScalarValue::TimeMicrosecond(None)
549+
ScalarValue::TimestampMicrosecond(None)
551550
}
552551
PrimitiveScalarType::TimeNanosecond => {
553-
ScalarValue::TimeNanosecond(None)
552+
ScalarValue::TimestampNanosecond(None)
554553
}
555554
PrimitiveScalarType::Null => {
556555
return Err(proto_error(
@@ -610,10 +609,10 @@ impl TryInto<datafusion::scalar::ScalarValue> for &protobuf::scalar_value::Value
610609
ScalarValue::Date32(Some(*v))
611610
}
612611
protobuf::scalar_value::Value::TimeMicrosecondValue(v) => {
613-
ScalarValue::TimeMicrosecond(Some(*v))
612+
ScalarValue::TimestampMicrosecond(Some(*v))
614613
}
615614
protobuf::scalar_value::Value::TimeNanosecondValue(v) => {
616-
ScalarValue::TimeNanosecond(Some(*v))
615+
ScalarValue::TimestampNanosecond(Some(*v))
617616
}
618617
protobuf::scalar_value::Value::ListValue(v) => v.try_into()?,
619618
protobuf::scalar_value::Value::NullListValue(v) => {
@@ -776,10 +775,10 @@ impl TryInto<datafusion::scalar::ScalarValue> for protobuf::PrimitiveScalarType
776775
protobuf::PrimitiveScalarType::LargeUtf8 => ScalarValue::LargeUtf8(None),
777776
protobuf::PrimitiveScalarType::Date32 => ScalarValue::Date32(None),
778777
protobuf::PrimitiveScalarType::TimeMicrosecond => {
779-
ScalarValue::TimeMicrosecond(None)
778+
ScalarValue::TimestampMicrosecond(None)
780779
}
781780
protobuf::PrimitiveScalarType::TimeNanosecond => {
782-
ScalarValue::TimeNanosecond(None)
781+
ScalarValue::TimestampNanosecond(None)
783782
}
784783
})
785784
}
@@ -829,10 +828,10 @@ impl TryInto<datafusion::scalar::ScalarValue> for &protobuf::ScalarValue {
829828
ScalarValue::Date32(Some(*v))
830829
}
831830
protobuf::scalar_value::Value::TimeMicrosecondValue(v) => {
832-
ScalarValue::TimeMicrosecond(Some(*v))
831+
ScalarValue::TimestampMicrosecond(Some(*v))
833832
}
834833
protobuf::scalar_value::Value::TimeNanosecondValue(v) => {
835-
ScalarValue::TimeNanosecond(Some(*v))
834+
ScalarValue::TimestampNanosecond(Some(*v))
836835
}
837836
protobuf::scalar_value::Value::ListValue(scalar_list) => {
838837
let protobuf::ScalarListValue {
@@ -962,6 +961,15 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
962961
let data_type = arrow_type.try_into()?;
963962
Ok(Expr::Cast { expr, data_type })
964963
}
964+
ExprType::TryCast(cast) => {
965+
let expr = Box::new(parse_required_expr(&cast.expr)?);
966+
let arrow_type: &protobuf::ArrowType = cast
967+
.arrow_type
968+
.as_ref()
969+
.ok_or_else(|| proto_error("Protobuf deserialization error: CastNode message missing required field 'arrow_type'"))?;
970+
let data_type = arrow_type.try_into()?;
971+
Ok(Expr::TryCast { expr, data_type })
972+
}
965973
ExprType::Sort(sort) => Ok(Expr::Sort {
966974
expr: Box::new(parse_required_expr(&sort.expr)?),
967975
asc: sort.asc,

rust/ballista/rust/core/src/serde/logical_plan/mod.rs

+14-14
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ mod roundtrip_tests {
8282
CsvReadOptions::new().schema(&schema).has_header(true),
8383
Some(vec![3, 4]),
8484
)
85-
.and_then(|plan| plan.sort(&[col("salary")]))
85+
.and_then(|plan| plan.sort(vec![col("salary")]))
8686
.and_then(|plan| plan.build())
8787
.map_err(BallistaError::DataFusionError)?,
8888
);
@@ -212,8 +212,8 @@ mod roundtrip_tests {
212212
ScalarValue::LargeUtf8(None),
213213
ScalarValue::List(None, DataType::Boolean),
214214
ScalarValue::Date32(None),
215-
ScalarValue::TimeMicrosecond(None),
216-
ScalarValue::TimeNanosecond(None),
215+
ScalarValue::TimestampMicrosecond(None),
216+
ScalarValue::TimestampNanosecond(None),
217217
ScalarValue::Boolean(Some(true)),
218218
ScalarValue::Boolean(Some(false)),
219219
ScalarValue::Float32(Some(1.0)),
@@ -252,11 +252,11 @@ mod roundtrip_tests {
252252
ScalarValue::LargeUtf8(Some(String::from("Test Large utf8"))),
253253
ScalarValue::Date32(Some(0)),
254254
ScalarValue::Date32(Some(i32::MAX)),
255-
ScalarValue::TimeNanosecond(Some(0)),
256-
ScalarValue::TimeNanosecond(Some(i64::MAX)),
257-
ScalarValue::TimeMicrosecond(Some(0)),
258-
ScalarValue::TimeMicrosecond(Some(i64::MAX)),
259-
ScalarValue::TimeMicrosecond(None),
255+
ScalarValue::TimestampNanosecond(Some(0)),
256+
ScalarValue::TimestampNanosecond(Some(i64::MAX)),
257+
ScalarValue::TimestampMicrosecond(Some(0)),
258+
ScalarValue::TimestampMicrosecond(Some(i64::MAX)),
259+
ScalarValue::TimestampMicrosecond(None),
260260
ScalarValue::List(
261261
Some(vec![
262262
ScalarValue::Float32(Some(-213.1)),
@@ -610,8 +610,8 @@ mod roundtrip_tests {
610610
ScalarValue::Utf8(None),
611611
ScalarValue::LargeUtf8(None),
612612
ScalarValue::Date32(None),
613-
ScalarValue::TimeMicrosecond(None),
614-
ScalarValue::TimeNanosecond(None),
613+
ScalarValue::TimestampMicrosecond(None),
614+
ScalarValue::TimestampNanosecond(None),
615615
//ScalarValue::List(None, DataType::Boolean)
616616
];
617617

@@ -679,7 +679,7 @@ mod roundtrip_tests {
679679
CsvReadOptions::new().schema(&schema).has_header(true),
680680
Some(vec![3, 4]),
681681
)
682-
.and_then(|plan| plan.sort(&[col("salary")]))
682+
.and_then(|plan| plan.sort(vec![col("salary")]))
683683
.and_then(|plan| plan.explain(true))
684684
.and_then(|plan| plan.build())
685685
.map_err(BallistaError::DataFusionError)?;
@@ -689,7 +689,7 @@ mod roundtrip_tests {
689689
CsvReadOptions::new().schema(&schema).has_header(true),
690690
Some(vec![3, 4]),
691691
)
692-
.and_then(|plan| plan.sort(&[col("salary")]))
692+
.and_then(|plan| plan.sort(vec![col("salary")]))
693693
.and_then(|plan| plan.explain(false))
694694
.and_then(|plan| plan.build())
695695
.map_err(BallistaError::DataFusionError)?;
@@ -742,7 +742,7 @@ mod roundtrip_tests {
742742
CsvReadOptions::new().schema(&schema).has_header(true),
743743
Some(vec![3, 4]),
744744
)
745-
.and_then(|plan| plan.sort(&[col("salary")]))
745+
.and_then(|plan| plan.sort(vec![col("salary")]))
746746
.and_then(|plan| plan.build())
747747
.map_err(BallistaError::DataFusionError)?;
748748
roundtrip_test!(plan);
@@ -784,7 +784,7 @@ mod roundtrip_tests {
784784
CsvReadOptions::new().schema(&schema).has_header(true),
785785
Some(vec![3, 4]),
786786
)
787-
.and_then(|plan| plan.aggregate(&[col("state")], &[max(col("salary"))]))
787+
.and_then(|plan| plan.aggregate(vec![col("state")], vec![max(col("salary"))]))
788788
.and_then(|plan| plan.build())
789789
.map_err(BallistaError::DataFusionError)?;
790790

rust/ballista/rust/core/src/serde/logical_plan/to_proto.rs

+4-10
Original file line numberDiff line numberDiff line change
@@ -641,12 +641,12 @@ impl TryFrom<&datafusion::scalar::ScalarValue> for protobuf::ScalarValue {
641641
datafusion::scalar::ScalarValue::Date32(val) => {
642642
create_proto_scalar(val, PrimitiveScalarType::Date32, |s| Value::Date32Value(*s))
643643
}
644-
datafusion::scalar::ScalarValue::TimeMicrosecond(val) => {
644+
datafusion::scalar::ScalarValue::TimestampMicrosecond(val) => {
645645
create_proto_scalar(val, PrimitiveScalarType::TimeMicrosecond, |s| {
646646
Value::TimeMicrosecondValue(*s)
647647
})
648648
}
649-
datafusion::scalar::ScalarValue::TimeNanosecond(val) => {
649+
datafusion::scalar::ScalarValue::TimestampNanosecond(val) => {
650650
create_proto_scalar(val, PrimitiveScalarType::TimeNanosecond, |s| {
651651
Value::TimeNanosecondValue(*s)
652652
})
@@ -939,10 +939,7 @@ impl TryInto<protobuf::LogicalPlanNode> for &LogicalPlan {
939939
})
940940
}
941941
LogicalPlan::Extension { .. } => unimplemented!(),
942-
// _ => Err(BallistaError::General(format!(
943-
// "logical plan to_proto {:?}",
944-
// self
945-
// ))),
942+
LogicalPlan::Union { .. } => unimplemented!(),
946943
}
947944
}
948945
}
@@ -1161,10 +1158,7 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
11611158
Expr::Wildcard => Ok(protobuf::LogicalExprNode {
11621159
expr_type: Some(protobuf::logical_expr_node::ExprType::Wildcard(true)),
11631160
}),
1164-
// _ => Err(BallistaError::General(format!(
1165-
// "logical expr to_proto {:?}",
1166-
// self
1167-
// ))),
1161+
Expr::TryCast { .. } => unimplemented!(),
11681162
}
11691163
}
11701164
}

0 commit comments

Comments
 (0)