Skip to content

Commit 1530c5b

Browse files
committed
Add default_cast_for
1 parent 7ed7891 commit 1530c5b

File tree

2 files changed

+160
-9
lines changed

2 files changed

+160
-9
lines changed

datafusion/common/src/types/logical.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use super::NativeType;
19+
use crate::Result;
20+
use arrow_schema::DataType;
1821
use core::fmt;
1922
use std::{cmp::Ordering, hash::Hash, sync::Arc};
2023

21-
use super::NativeType;
22-
2324
/// Signature that uniquely identifies a type among other types.
2425
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
2526
pub enum TypeSignature<'a> {
@@ -75,8 +76,17 @@ pub type LogicalTypeRef = Arc<dyn LogicalType>;
7576
/// }
7677
/// ```
7778
pub trait LogicalType: Sync + Send {
79+
/// Get the native backing type of this logical type.
7880
fn native(&self) -> &NativeType;
81+
/// Get the unique type signature for this logical type. Logical types with identical
82+
/// signatures are considered equal.
7983
fn signature(&self) -> TypeSignature<'_>;
84+
85+
/// Get the default physical type to cast `origin` to in order to obtain a physical type
86+
/// that is logically compatible with this logical type.
87+
fn default_cast_for(&self, origin: &DataType) -> Result<DataType> {
88+
self.native().default_cast_for(origin)
89+
}
8090
}
8191

8292
impl fmt::Debug for dyn LogicalType {
@@ -90,7 +100,7 @@ impl fmt::Debug for dyn LogicalType {
90100

91101
impl PartialEq for dyn LogicalType {
92102
fn eq(&self, other: &Self) -> bool {
93-
self.native().eq(other.native()) && self.signature().eq(&other.signature())
103+
self.signature().eq(&other.signature())
94104
}
95105
}
96106

datafusion/common/src/types/native.rs

Lines changed: 147 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::sync::Arc;
19-
20-
use arrow_schema::{DataType, IntervalUnit, TimeUnit};
21-
2218
use super::{
23-
LogicalFieldRef, LogicalFields, LogicalType, LogicalUnionFields, TypeSignature,
19+
LogicalField, LogicalFieldRef, LogicalFields, LogicalType, LogicalUnionFields,
20+
TypeSignature,
2421
};
22+
use crate::{internal_err, Result};
23+
use arrow_schema::{DataType, Field, FieldRef, IntervalUnit, TimeUnit};
24+
use std::sync::Arc;
2525

2626
/// Representation of a type that DataFusion can handle natively. It is a subset
2727
/// of the physical variants in Arrow's native [`DataType`].
@@ -188,6 +188,147 @@ impl LogicalType for NativeType {
188188
/// Returns the signature of this native type: the [`TypeSignature::Native`] variant
/// wrapping a borrow of `self`.
fn signature(&self) -> TypeSignature<'_> {
189189
TypeSignature::Native(self)
190190
}
191+
192+
fn default_cast_for(&self, origin: &DataType) -> Result<DataType> {
193+
use DataType::*;
194+
195+
fn default_field_cast(to: &LogicalField, from: &Field) -> Result<FieldRef> {
196+
Ok(Arc::new(Field::new(
197+
to.name.clone(),
198+
to.logical_type.default_cast_for(from.data_type())?,
199+
to.nullable,
200+
)))
201+
}
202+
203+
Ok(match (self, origin) {
204+
(Self::Null, _) => Null,
205+
(Self::Boolean, _) => Boolean,
206+
(Self::Int8, _) => Int8,
207+
(Self::Int16, _) => Int16,
208+
(Self::Int32, _) => Int32,
209+
(Self::Int64, _) => Int64,
210+
(Self::UInt8, _) => UInt8,
211+
(Self::UInt16, _) => UInt16,
212+
(Self::UInt32, _) => UInt32,
213+
(Self::UInt64, _) => UInt64,
214+
(Self::Float16, _) => Float16,
215+
(Self::Float32, _) => Float32,
216+
(Self::Float64, _) => Float64,
217+
(Self::Decimal(p, s), _) if p <= &38 => Decimal128(p.clone(), s.clone()),
218+
(Self::Decimal(p, s), _) => Decimal256(p.clone(), s.clone()),
219+
(Self::Timestamp(tu, tz), _) => Timestamp(tu.clone(), tz.clone()),
220+
(Self::Date, _) => Date32,
221+
(Self::Time(tu), _) => match tu {
222+
TimeUnit::Second | TimeUnit::Millisecond => Time32(tu.clone()),
223+
TimeUnit::Microsecond | TimeUnit::Nanosecond => Time64(tu.clone()),
224+
},
225+
(Self::Duration(tu), _) => Duration(tu.clone()),
226+
(Self::Interval(iu), _) => Interval(iu.clone()),
227+
(Self::Binary, LargeUtf8) => LargeBinary,
228+
(Self::Binary, Utf8View) => BinaryView,
229+
(Self::Binary, _) => Binary,
230+
(Self::FixedSizeBinary(size), _) => FixedSizeBinary(size.clone()),
231+
(Self::Utf8, LargeBinary) => LargeUtf8,
232+
(Self::Utf8, BinaryView) => Utf8View,
233+
(Self::Utf8, _) => Utf8,
234+
(Self::List(to_field), List(from_field) | FixedSizeList(from_field, _)) => {
235+
List(default_field_cast(to_field, from_field)?)
236+
}
237+
(Self::List(to_field), LargeList(from_field)) => {
238+
LargeList(default_field_cast(to_field, from_field)?)
239+
}
240+
(Self::List(to_field), ListView(from_field)) => {
241+
ListView(default_field_cast(to_field, from_field)?)
242+
}
243+
(Self::List(to_field), LargeListView(from_field)) => {
244+
LargeListView(default_field_cast(to_field, from_field)?)
245+
}
246+
// List array where each element is a len 1 list of the origin type
247+
(Self::List(field), _) => List(Arc::new(Field::new(
248+
field.name.clone(),
249+
field.logical_type.default_cast_for(origin)?,
250+
field.nullable,
251+
))),
252+
(
253+
Self::FixedSizeList(to_field, to_size),
254+
FixedSizeList(from_field, from_size),
255+
) if from_size == to_size => {
256+
FixedSizeList(default_field_cast(to_field, from_field)?, to_size.clone())
257+
}
258+
(
259+
Self::FixedSizeList(to_field, size),
260+
List(from_field)
261+
| LargeList(from_field)
262+
| ListView(from_field)
263+
| LargeListView(from_field),
264+
) => FixedSizeList(default_field_cast(to_field, from_field)?, size.clone()),
265+
// FixedSizeList array where each element is a len 1 list of the origin type
266+
(Self::FixedSizeList(field, size), _) => FixedSizeList(
267+
Arc::new(Field::new(
268+
field.name.clone(),
269+
field.logical_type.default_cast_for(origin)?,
270+
field.nullable,
271+
)),
272+
size.clone(),
273+
),
274+
// From https://github.com/apache/arrow-rs/blob/56525efbd5f37b89d1b56aa51709cab9f81bc89e/arrow-cast/src/cast/mod.rs#L189-L196
275+
(Self::Struct(to_fields), Struct(from_fields))
276+
if from_fields.len() == to_fields.len() =>
277+
{
278+
Struct(
279+
from_fields
280+
.iter()
281+
.zip(to_fields.iter())
282+
.map(|(from, to)| default_field_cast(to, from))
283+
.collect()?,
284+
)
285+
}
286+
(Self::Struct(to_fields), Null) => Struct(
287+
to_fields
288+
.iter()
289+
.map(|field| {
290+
Ok(Arc::new(Field::new(
291+
field.name.clone(),
292+
field.logical_type.default_cast_for(&Null)?,
293+
field.nullable,
294+
)))
295+
})
296+
.collect()?,
297+
),
298+
(Self::Map(to_field), Map(from_field, sorted)) => {
299+
Map(default_field_cast(to_field, from_field)?, sorted.clone())
300+
}
301+
(Self::Map(field), Null) => Map(
302+
Arc::new(Field::new(
303+
field.name.clone(),
304+
field.logical_type.default_cast_for(&Null)?,
305+
field.nullable,
306+
)),
307+
false,
308+
),
309+
(Self::Union(to_fields), Union(from_fields, mode))
310+
if from_fields.len() == to_fields.len() =>
311+
{
312+
Union(
313+
from_fields
314+
.iter()
315+
.zip(to_fields.iter())
316+
.map(|((_, from), (i, to))| {
317+
(i.clone(), default_field_cast(to, from))
318+
})
319+
.collect()?,
320+
mode.clone(),
321+
)
322+
}
323+
_ => {
324+
return internal_err!(
325+
"Unavailable default cast for native type {:?} from physical type {:?}",
326+
self,
327+
origin
328+
)
329+
}
330+
})
331+
}
191332
}
192333

193334
// The following From<DataType>, From<Field>, ... implementations are temporary
@@ -230,9 +371,9 @@ impl From<DataType> for NativeType {
230371
DataType::Union(union_fields, _) => {
231372
Union(LogicalUnionFields::from(&union_fields))
232373
}
233-
DataType::Dictionary(_, data_type) => data_type.as_ref().clone().into(),
234374
DataType::Decimal128(p, s) | DataType::Decimal256(p, s) => Decimal(p, s),
235375
DataType::Map(field, _) => Map(Arc::new(field.as_ref().into())),
376+
DataType::Dictionary(_, data_type) => data_type.as_ref().clone().into(),
236377
DataType::RunEndEncoded(_, field) => field.data_type().clone().into(),
237378
}
238379
}

0 commit comments

Comments
 (0)