|
15 | 15 | // specific language governing permissions and limitations
|
16 | 16 | // under the License.
|
17 | 17 |
|
18 |
| -use std::sync::Arc; |
19 |
| - |
20 |
| -use arrow_schema::{DataType, IntervalUnit, TimeUnit}; |
21 |
| - |
22 | 18 | use super::{
|
23 |
| - LogicalFieldRef, LogicalFields, LogicalType, LogicalUnionFields, TypeSignature, |
| 19 | + LogicalField, LogicalFieldRef, LogicalFields, LogicalType, LogicalUnionFields, |
| 20 | + TypeSignature, |
24 | 21 | };
|
| 22 | +use crate::{internal_err, Result}; |
| 23 | +use arrow_schema::{DataType, Field, FieldRef, IntervalUnit, TimeUnit}; |
| 24 | +use std::sync::Arc; |
25 | 25 |
|
26 | 26 | /// Representation of a type that DataFusion can handle natively. It is a subset
|
27 | 27 | /// of the physical variants in Arrow's native [`DataType`].
|
@@ -188,6 +188,147 @@ impl LogicalType for NativeType {
|
188 | 188 | fn signature(&self) -> TypeSignature<'_> {
|
189 | 189 | TypeSignature::Native(self)
|
190 | 190 | }
|
| 191 | + |
| 192 | + fn default_cast_for(&self, origin: &DataType) -> Result<DataType> { |
| 193 | + use DataType::*; |
| 194 | + |
| 195 | + fn default_field_cast(to: &LogicalField, from: &Field) -> Result<FieldRef> { |
| 196 | + Ok(Arc::new(Field::new( |
| 197 | + to.name.clone(), |
| 198 | + to.logical_type.default_cast_for(from.data_type())?, |
| 199 | + to.nullable, |
| 200 | + ))) |
| 201 | + } |
| 202 | + |
| 203 | + Ok(match (self, origin) { |
| 204 | + (Self::Null, _) => Null, |
| 205 | + (Self::Boolean, _) => Boolean, |
| 206 | + (Self::Int8, _) => Int8, |
| 207 | + (Self::Int16, _) => Int16, |
| 208 | + (Self::Int32, _) => Int32, |
| 209 | + (Self::Int64, _) => Int64, |
| 210 | + (Self::UInt8, _) => UInt8, |
| 211 | + (Self::UInt16, _) => UInt16, |
| 212 | + (Self::UInt32, _) => UInt32, |
| 213 | + (Self::UInt64, _) => UInt64, |
| 214 | + (Self::Float16, _) => Float16, |
| 215 | + (Self::Float32, _) => Float32, |
| 216 | + (Self::Float64, _) => Float64, |
| 217 | + (Self::Decimal(p, s), _) if p <= &38 => Decimal128(p.clone(), s.clone()), |
| 218 | + (Self::Decimal(p, s), _) => Decimal256(p.clone(), s.clone()), |
| 219 | + (Self::Timestamp(tu, tz), _) => Timestamp(tu.clone(), tz.clone()), |
| 220 | + (Self::Date, _) => Date32, |
| 221 | + (Self::Time(tu), _) => match tu { |
| 222 | + TimeUnit::Second | TimeUnit::Millisecond => Time32(tu.clone()), |
| 223 | + TimeUnit::Microsecond | TimeUnit::Nanosecond => Time64(tu.clone()), |
| 224 | + }, |
| 225 | + (Self::Duration(tu), _) => Duration(tu.clone()), |
| 226 | + (Self::Interval(iu), _) => Interval(iu.clone()), |
| 227 | + (Self::Binary, LargeUtf8) => LargeBinary, |
| 228 | + (Self::Binary, Utf8View) => BinaryView, |
| 229 | + (Self::Binary, _) => Binary, |
| 230 | + (Self::FixedSizeBinary(size), _) => FixedSizeBinary(size.clone()), |
| 231 | + (Self::Utf8, LargeBinary) => LargeUtf8, |
| 232 | + (Self::Utf8, BinaryView) => Utf8View, |
| 233 | + (Self::Utf8, _) => Utf8, |
| 234 | + (Self::List(to_field), List(from_field) | FixedSizeList(from_field, _)) => { |
| 235 | + List(default_field_cast(to_field, from_field)?) |
| 236 | + } |
| 237 | + (Self::List(to_field), LargeList(from_field)) => { |
| 238 | + LargeList(default_field_cast(to_field, from_field)?) |
| 239 | + } |
| 240 | + (Self::List(to_field), ListView(from_field)) => { |
| 241 | + ListView(default_field_cast(to_field, from_field)?) |
| 242 | + } |
| 243 | + (Self::List(to_field), LargeListView(from_field)) => { |
| 244 | + LargeListView(default_field_cast(to_field, from_field)?) |
| 245 | + } |
| 246 | + // List array where each element is a len 1 list of the origin type |
| 247 | + (Self::List(field), _) => List(Arc::new(Field::new( |
| 248 | + field.name.clone(), |
| 249 | + field.logical_type.default_cast_for(origin)?, |
| 250 | + field.nullable, |
| 251 | + ))), |
| 252 | + ( |
| 253 | + Self::FixedSizeList(to_field, to_size), |
| 254 | + FixedSizeList(from_field, from_size), |
| 255 | + ) if from_size == to_size => { |
| 256 | + FixedSizeList(default_field_cast(to_field, from_field)?, to_size.clone()) |
| 257 | + } |
| 258 | + ( |
| 259 | + Self::FixedSizeList(to_field, size), |
| 260 | + List(from_field) |
| 261 | + | LargeList(from_field) |
| 262 | + | ListView(from_field) |
| 263 | + | LargeListView(from_field), |
| 264 | + ) => FixedSizeList(default_field_cast(to_field, from_field)?, size.clone()), |
| 265 | + // FixedSizeList array where each element is a len 1 list of the origin type |
| 266 | + (Self::FixedSizeList(field, size), _) => FixedSizeList( |
| 267 | + Arc::new(Field::new( |
| 268 | + field.name.clone(), |
| 269 | + field.logical_type.default_cast_for(origin)?, |
| 270 | + field.nullable, |
| 271 | + )), |
| 272 | + size.clone(), |
| 273 | + ), |
| 274 | + // From https://github.com/apache/arrow-rs/blob/56525efbd5f37b89d1b56aa51709cab9f81bc89e/arrow-cast/src/cast/mod.rs#L189-L196 |
| 275 | + (Self::Struct(to_fields), Struct(from_fields)) |
| 276 | + if from_fields.len() == to_fields.len() => |
| 277 | + { |
| 278 | + Struct( |
| 279 | + from_fields |
| 280 | + .iter() |
| 281 | + .zip(to_fields.iter()) |
| 282 | + .map(|(from, to)| default_field_cast(to, from)) |
| 283 | + .collect()?, |
| 284 | + ) |
| 285 | + } |
| 286 | + (Self::Struct(to_fields), Null) => Struct( |
| 287 | + to_fields |
| 288 | + .iter() |
| 289 | + .map(|field| { |
| 290 | + Ok(Arc::new(Field::new( |
| 291 | + field.name.clone(), |
| 292 | + field.logical_type.default_cast_for(&Null)?, |
| 293 | + field.nullable, |
| 294 | + ))) |
| 295 | + }) |
| 296 | + .collect()?, |
| 297 | + ), |
| 298 | + (Self::Map(to_field), Map(from_field, sorted)) => { |
| 299 | + Map(default_field_cast(to_field, from_field)?, sorted.clone()) |
| 300 | + } |
| 301 | + (Self::Map(field), Null) => Map( |
| 302 | + Arc::new(Field::new( |
| 303 | + field.name.clone(), |
| 304 | + field.logical_type.default_cast_for(&Null)?, |
| 305 | + field.nullable, |
| 306 | + )), |
| 307 | + false, |
| 308 | + ), |
| 309 | + (Self::Union(to_fields), Union(from_fields, mode)) |
| 310 | + if from_fields.len() == to_fields.len() => |
| 311 | + { |
| 312 | + Union( |
| 313 | + from_fields |
| 314 | + .iter() |
| 315 | + .zip(to_fields.iter()) |
| 316 | + .map(|((_, from), (i, to))| { |
| 317 | + (i.clone(), default_field_cast(to, from)) |
| 318 | + }) |
| 319 | + .collect()?, |
| 320 | + mode.clone(), |
| 321 | + ) |
| 322 | + } |
| 323 | + _ => { |
| 324 | + return internal_err!( |
| 325 | + "Unavailable default cast for native type {:?} from physical type {:?}", |
| 326 | + self, |
| 327 | + origin |
| 328 | + ) |
| 329 | + } |
| 330 | + }) |
| 331 | + } |
191 | 332 | }
|
192 | 333 |
|
193 | 334 | // The following From<DataType>, From<Field>, ... implementations are temporary
|
@@ -230,9 +371,9 @@ impl From<DataType> for NativeType {
|
230 | 371 | DataType::Union(union_fields, _) => {
|
231 | 372 | Union(LogicalUnionFields::from(&union_fields))
|
232 | 373 | }
|
233 |
| - DataType::Dictionary(_, data_type) => data_type.as_ref().clone().into(), |
234 | 374 | DataType::Decimal128(p, s) | DataType::Decimal256(p, s) => Decimal(p, s),
|
235 | 375 | DataType::Map(field, _) => Map(Arc::new(field.as_ref().into())),
|
| 376 | + DataType::Dictionary(_, data_type) => data_type.as_ref().clone().into(), |
236 | 377 | DataType::RunEndEncoded(_, field) => field.data_type().clone().into(),
|
237 | 378 | }
|
238 | 379 | }
|
|
0 commit comments